bijx.DiffraxConfig
- class bijx.DiffraxConfig
Bases: object
Configuration for diffrax ODE solving in continuous normalizing flows.
Encapsulates all parameters needed for diffrax-based ODE integration, including solver choice, step size control, and adjoint method selection. The optional_override and replace methods return modified copies of a config, so parameters such as the integration interval can be adjusted at runtime without mutating the original.
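- Parameters:
solver (AbstractSolver) – Diffrax solver instance (default: Tsit5(), an explicit Runge-Kutta solver that also supports adaptive stepping).
t_start (float) – Integration start time.
t_end (float) – Integration end time.
dt (float) – Step size: the fixed step under constant step size control, or the initial step for an adaptive controller.
saveat (SaveAt) – Configuration for which time points to save.
stepsize_controller (AbstractStepSizeController) – Step size control strategy. The default ConstantStepSize() uses a fixed step of dt; an adaptive controller such as PIDController adjusts the step during integration.
adjoint (AbstractAdjoint) – Adjoint method used to compute gradients through the solve.
event (Event | None) – Optional event that can terminate integration early.
max_steps (int | None) – Maximum number of integration steps allowed.
throw (bool) – Whether to raise an error if integration fails (for example, when max_steps is exceeded).
solver_state (Any | None) – Initial solver internal state, typically obtained from a previous solve.
controller_state (Any | None) – Initial step size controller state, typically obtained from a previous solve.
made_jump (bool | None) – Whether a jump (discontinuity) has just been made at the start time, typically obtained from a previous solve.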
Note
For more information, see [diffrax’s documentation](https://docs.kidger.site/diffrax/).
Example
>>> import diffrax
>>> from bijx import DiffraxConfig
>>> config = DiffraxConfig(
...     solver=diffrax.Dopri5(),
...     dt=0.1,
...     adjoint=diffrax.RecursiveCheckpointAdjoint(),
... )
- __init__(solver=Tsit5(), t_start=0.0, t_end=1.0, dt=0.05, saveat=SaveAt(subs=SubSaveAt(t1=True)), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(), event=None, max_steps=4096, throw=True, solver_state=None, controller_state=None, made_jump=None)
- Parameters:
solver (AbstractSolver)
t_start (float)
t_end (float)
dt (float)
saveat (SaveAt)
stepsize_controller (AbstractStepSizeController)
adjoint (AbstractAdjoint)
event (Event | None)
max_steps (int | None)
throw (bool)
solver_state (Any | None)
controller_state (Any | None)
made_jump (bool | None)
- Return type:
None
Methods
- optional_override(*[, t_start, t_end, dt, ...]) – Create new config with optionally overridden parameters.
- replace(**updates) – Returns a new object replacing the specified fields with new values.
- solve(terms, y0, args) – Solve ODE using configured diffrax solver.
Attributes
- solver: AbstractSolver = Tsit5()
- t_start: float = 0.0
- t_end: float = 1.0
- dt: float = 0.05
- saveat: SaveAt = SaveAt(subs=SubSaveAt(t1=True))
- stepsize_controller: AbstractStepSizeController = ConstantStepSize()
- adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint()
- event: Event | None = None
- max_steps: int | None = 4096
- throw: bool = True
- solver_state: Any | None = None
- controller_state: Any | None = None
- made_jump: bool | None = None
- optional_override(*, t_start=None, t_end=None, dt=None, saveat=None, solver_state=None, controller_state=None)
Create new config with optionally overridden parameters.
- Parameters:
t_start (float | None) – Override integration start time.
t_end (float | None) – Override integration end time.
dt (float | None) – Override step size.
saveat (SaveAt | None) – Override save configuration.
solver_state (Any | None) – Override solver internal state.
controller_state (Any | None) – Override step size controller state.
- Returns:
New DiffraxConfig with the specified (non-None) parameters overridden; all other fields are copied from the current config.
- solve(terms, y0, args)
Solve ODE using configured diffrax solver.
- Parameters:
terms – Diffrax ODE terms defining the vector field.
y0 – Initial condition.
args – Additional arguments passed to the vector field.
- Returns:
Diffrax solution object (diffrax.Solution) containing the integration results, such as the saved times ts and states ys.
- replace(**updates)
Returns a new object replacing the specified fields with new values.