bijx.DiffraxConfig
- class bijx.DiffraxConfig
Bases: Pytree
Configuration for diffrax ODE solving in continuous normalizing flows.
Encapsulates all parameters needed for diffrax-based ODE integration, including solver choice, step size control, and adjoint method selection. Individual parameters can be overridden at runtime with optional_override() or replace(), each of which returns a new config.
- Parameters:
solver (AbstractSolver) – Diffrax solver instance (default: Tsit5()). Tsit5 only adapts its step size when paired with an adaptive stepsize_controller; with the default ConstantStepSize() it takes fixed steps of size dt.
t_start (float) – Integration start time.
t_end (float) – Integration end time.
dt (float) – Step size for integration; with an adaptive stepsize_controller, only the initial step size.
saveat (SaveAt) – Time points at which to save the solution.
stepsize_controller (AbstractStepSizeController) – Step size control strategy (constant or adaptive).
adjoint (AbstractAdjoint) – Adjoint method used to compute gradients through the solve.
event (Event | None) – Optional event to detect during integration.
max_steps (int | None) – Maximum number of integration steps allowed.
throw (bool) – Whether to raise an exception if integration fails.
solver_state (Any | None) – Initial internal state of the solver.
controller_state (Any | None) – Initial state of the step size controller.
made_jump (bool | None) – Whether a discontinuous jump has just been made at t_start.
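The split between solver and stepsize_controller can be illustrated with a toy, stdlib-only sketch. It is not diffrax code: a Heun step with an embedded Euler error estimate stands in for Tsit5, and a simple accept/reject rule with clipped step rescaling stands in for an adaptive controller such as diffrax.PIDController. The function name heun_euler_adaptive and its rtol/atol defaults are hypothetical; here dt only seeds the first step, as it does when an adaptive controller is used.

```python
import math
from typing import Callable, Tuple


def heun_euler_adaptive(
    f: Callable[[float, float], float],
    y0: float,
    t_start: float,
    t_end: float,
    dt: float,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 4096,
) -> Tuple[float, int]:
    """Toy adaptive integrator for a scalar ODE dy/dt = f(t, y).

    The "solver" part is a Heun step with an embedded Euler error estimate;
    the "controller" part accepts or rejects steps and rescales h.
    Returns y(t_end) and the number of accepted steps.
    """
    t, y, h = t_start, y0, dt  # dt is only the initial step size here
    accepted = 0
    for _ in range(max_steps):  # counts attempted (including rejected) steps
        if t >= t_end:
            return y, accepted
        h = min(h, t_end - t)  # never step past t_end
        # Solver: first-order Euler and second-order Heun solutions.
        k1 = f(t, y)
        y_euler = y + h * k1
        k2 = f(t + h, y_euler)
        y_heun = y + 0.5 * h * (k1 + k2)
        err = abs(y_heun - y_euler)  # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_heun))
        # Controller: accept the step if the estimate is within tolerance...
        if err <= tol:
            t, y = t + h, y_heun
            accepted += 1
        # ...and propose the next step size (shrink or grow, clipped).
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 5.0
        h *= min(5.0, max(0.2, factor))
    raise RuntimeError("max_steps exceeded before reaching t_end")


# Usage: dy/dt = -y on [0, 1], starting from dt = 0.05.
y1, n_accepted = heun_euler_adaptive(lambda t, y: -y, 1.0, 0.0, 1.0, dt=0.05)
print(y1, math.exp(-1.0), n_accepted)
```

With a tight tolerance the controller shrinks the step well below the initial dt, so many more than (1.0 - 0.0) / 0.05 = 20 steps are accepted.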
Note
For more information, see [diffrax’s documentation](https://docs.kidger.site/diffrax/).
Example
>>> import diffrax
>>> from bijx import DiffraxConfig
>>> config = DiffraxConfig(
...     solver=diffrax.Dopri5(),
...     dt=0.1,
...     adjoint=diffrax.RecursiveCheckpointAdjoint(),
... )
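With the default ConstantStepSize() controller, dt fixes the number of steps, so a quick estimate shows whether max_steps is a concern. For the example above (dt=0.1 on the default interval [0.0, 1.0]) that is about 10 steps. The helper constant_step_count below is hypothetical and only approximates diffrax's stepping, assuming the final step is shortened to land on t_end.

```python
import math


def constant_step_count(t_start: float, t_end: float, dt: float) -> int:
    """Rough number of steps a constant step size dt needs to cover [t_start, t_end].

    Assumes the final step is shortened to land on t_end; the small
    tolerance absorbs floating-point noise such as 1.0 / 0.1.
    """
    return max(1, math.ceil(abs(t_end - t_start) / abs(dt) - 1e-9))


DEFAULT_MAX_STEPS = 4096  # default max_steps listed for DiffraxConfig

for dt in (0.1, 0.05, 1e-4):
    n = constant_step_count(0.0, 1.0, dt)
    print(f"dt={dt}: about {n} steps, within max_steps: {n <= DEFAULT_MAX_STEPS}")
```

A very small constant dt such as 1e-4 would need about 10000 steps, more than the default max_steps=4096.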
- __init__(solver=Tsit5(), t_start=0.0, t_end=1.0, dt=0.05, saveat=SaveAt(subs=SubSaveAt(t1=True)), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(), event=None, max_steps=4096, throw=True, solver_state=None, controller_state=None, made_jump=None)
- Parameters:
solver (AbstractSolver)
t_start (float)
t_end (float)
dt (float)
saveat (SaveAt)
stepsize_controller (AbstractStepSizeController)
adjoint (AbstractAdjoint)
event (Event | None)
max_steps (int | None)
throw (bool)
solver_state (Any | None)
controller_state (Any | None)
made_jump (bool | None)
- Return type:
None
Methods
optional_override(*[, t_start, t_end, dt, ...]) – Create new config with optionally overridden parameters.
replace(**changes) – Create new config with specified parameters replaced.
solve(terms, y0, args) – Solve ODE using configured diffrax solver.
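As a rough picture of what solve(terms, y0, args) computes under the default settings (t_start=0.0, t_end=1.0, dt=0.05, constant step size, only the final state saved via SaveAt(t1=True)), here is a hypothetical stdlib-only stand-in. It assumes a scalar ODE given as a plain function f(t, y, args), the calling convention of diffrax vector fields, instead of diffrax terms, and uses classic RK4 in place of Tsit5; it is not how bijx implements solve.

```python
import math
from typing import Callable


def solve_fixed_step_rk4(
    f: Callable[[float, float, object], float],
    y0: float,
    args: object,
    t_start: float = 0.0,
    t_end: float = 1.0,
    dt: float = 0.05,
) -> float:
    """Toy stand-in for solve(): constant-step classic RK4, returning only y(t_end)."""
    n_steps = max(1, math.ceil((t_end - t_start) / dt - 1e-9))
    h = (t_end - t_start) / n_steps  # equal steps that land exactly on t_end
    t, y = t_start, y0
    for _ in range(n_steps):
        k1 = f(t, y, args)
        k2 = f(t + h / 2, y + h / 2 * k1, args)
        k3 = f(t + h / 2, y + h / 2 * k2, args)
        k4 = f(t + h, y + h * k3, args)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return y  # like SaveAt(t1=True): only the final state is kept


# Usage: dy/dt = -rate * y, with rate passed through args.
y_end = solve_fixed_step_rk4(lambda t, y, rate: -rate * y, 1.0, 2.0)
print(y_end, math.exp(-2.0))
```

Because args is forwarded to the vector field, a parameter such as rate can change between calls without rebuilding the integrator.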
Attributes
- solver: AbstractSolver = Tsit5()
- t_start: float = 0.0
- t_end: float = 1.0
- dt: float = 0.05
- saveat: SaveAt = SaveAt(subs=SubSaveAt(t1=True))
- stepsize_controller: AbstractStepSizeController = ConstantStepSize()
- adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint()
- event: Event | None = None
- max_steps: int | None = 4096
- throw: bool = True
- solver_state: Any | None = None
- controller_state: Any | None = None
- made_jump: bool | None = None
- optional_override(*, t_start=None, t_end=None, dt=None, saveat=None, solver_state=None, controller_state=None)
Create new config with optionally overridden parameters. Arguments left as None keep their current values.
- Parameters:
t_start (float | None) – Override integration start time.
t_end (float | None) – Override integration end time.
dt (float | None) – Override step size.
saveat (SaveAt | None) – Override save configuration.
solver_state (Any | None) – Override solver internal state.
controller_state (Any | None) – Override step size controller state.
- Returns:
New DiffraxConfig with specified parameters overridden.
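A minimal sketch of the override semantics, assuming that a None argument means "keep the current value" and that the config is immutable, so a new object is returned. ToyConfig is a hypothetical frozen dataclass carrying only a few DiffraxConfig-like fields; it is not bijx's implementation.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in holding a few DiffraxConfig-like fields."""

    t_start: float = 0.0
    t_end: float = 1.0
    dt: float = 0.05
    solver_state: Optional[Any] = None

    def optional_override(
        self,
        *,
        t_start: Optional[float] = None,
        t_end: Optional[float] = None,
        dt: Optional[float] = None,
        solver_state: Optional[Any] = None,
    ) -> "ToyConfig":
        # Collect candidate overrides; None means "keep the current value".
        candidates = {
            "t_start": t_start,
            "t_end": t_end,
            "dt": dt,
            "solver_state": solver_state,
        }
        changes = {k: v for k, v in candidates.items() if v is not None}
        # Frozen dataclass: build a new instance instead of mutating self.
        return replace(self, **changes)


base = ToyConfig()
fine = base.optional_override(dt=0.01, solver_state="warm")
kept = fine.optional_override(solver_state=None)  # None keeps "warm"
cleared = replace(fine, solver_state=None)  # explicit reset to None
print(fine)  # ToyConfig(t_start=0.0, t_end=1.0, dt=0.01, solver_state='warm')
```

In this sketch, the None-means-keep convention has one consequence: optional_override cannot reset a field to None, whereas a plain dataclasses.replace call can.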