bijx.DiffraxConfig
- class bijx.DiffraxConfig
Bases: Pytree
Configuration for diffrax ODE solving in continuous normalizing flows.
Encapsulates all parameters needed for diffrax-based ODE integration, including solver choice, step size control, and adjoint method selection. Provides convenient parameter override functionality for runtime configuration.
- Parameters:
solver (AbstractSolver) – Diffrax solver instance (default: Tsit5 adaptive solver).
t_start (float) – Integration start time.
t_end (float) – Integration end time.
dt (float) – (Initial) step size for integration.
saveat (SaveAt) – Configuration for which time points to save.
stepsize_controller (AbstractStepSizeController) – Strategy for adaptive step size control.
adjoint (AbstractAdjoint) – Adjoint method for gradient computation.
event (Event | None) – Optional event detection during integration.
max_steps (int | None) – Maximum number of integration steps allowed.
throw (bool) – Whether to raise exceptions on integration failure.
solver_state (Any | None) – Initial solver internal state.
controller_state (Any | None) – Initial step size controller state.
made_jump (bool | None) – Whether the solver has made a discontinuous jump.
Note
For more information, see [diffrax’s documentation](https://docs.kidger.site/diffrax/).
Example
>>> config = DiffraxConfig(
...     solver=diffrax.Dopri5(),
...     dt=0.1,
...     adjoint=diffrax.RecursiveCheckpointAdjoint()
... )
- __init__(solver=Tsit5(), t_start=0.0, t_end=1.0, dt=0.05, saveat=SaveAt(subs=SubSaveAt(t1=True)), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(), event=None, max_steps=4096, throw=True, solver_state=None, controller_state=None, made_jump=None)
- Parameters:
solver (AbstractSolver)
t_start (float)
t_end (float)
dt (float)
saveat (SaveAt)
stepsize_controller (AbstractStepSizeController)
adjoint (AbstractAdjoint)
event (Event | None)
max_steps (int | None)
throw (bool)
solver_state (Any | None)
controller_state (Any | None)
made_jump (bool | None)
- Return type:
None
Methods
optional_override(*[, t_start, t_end, dt, ...]) – Create new config with optionally overridden parameters.
replace(**changes) – Create new config with specified parameters replaced.
solve(terms, y0, args) – Solve ODE using configured diffrax solver.
solve_sde(drift, diffusion, y0, rng[, args, ...]) – Solve SDE using configured parameters.
Attributes
- solver: AbstractSolver = Tsit5()
- t_start: float = 0.0
- t_end: float = 1.0
- dt: float = 0.05
- saveat: SaveAt = SaveAt(subs=SubSaveAt(t1=True))
- stepsize_controller: AbstractStepSizeController = ConstantStepSize()
- adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint()
- event: Event | None = None
- max_steps: int | None = 4096
- throw: bool = True
- solver_state: Any | None = None
- controller_state: Any | None = None
- made_jump: bool | None = None
- optional_override(*, t_start=None, t_end=None, dt=None, saveat=None, solver_state=None, controller_state=None)
Create new config with optionally overridden parameters.
- Parameters:
t_start (float | None) – Override integration start time.
t_end (float | None) – Override integration end time.
dt (float | None) – Override step size.
saveat (SaveAt | None) – Override save configuration.
solver_state (Any | None) – Override solver internal state.
controller_state (Any | None) – Override step size controller state.
- Returns:
New DiffraxConfig with specified parameters overridden.
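The `None` defaults in the signature suggest that an argument left as `None` keeps the configured value, while any other value replaces it in a new object. The following is a minimal standard-library sketch of that pattern under that assumption. `ToyConfig` is a hypothetical stand-in with only three fields; it is not the bijx class and does not reproduce its implementation.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in for a solver config (not bijx.DiffraxConfig)."""

    t_start: float = 0.0
    t_end: float = 1.0
    dt: float = 0.05

    def optional_override(
        self,
        *,
        t_start: Optional[float] = None,
        t_end: Optional[float] = None,
        dt: Optional[float] = None,
    ) -> "ToyConfig":
        # Assumed semantics: None means "keep the current value".
        candidates = {"t_start": t_start, "t_end": t_end, "dt": dt}
        # Compare against None explicitly so that 0.0 is a valid override.
        changes = {k: v for k, v in candidates.items() if v is not None}
        # replace() builds a new instance; the original is left untouched.
        return replace(self, **changes)


base = ToyConfig()
# Swap the integration limits while keeping the configured step size.
swapped = base.optional_override(t_start=1.0, t_end=0.0)
print(base)
print(swapped)
```

Checking `is not None` rather than truthiness matters here, because `t_end=0.0` is a legitimate override that a plain `if value:` test would silently drop.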
- solve(terms, y0, args)
Solve ODE using configured diffrax solver.
- Parameters:
terms – Diffrax ODE terms defining the vector field.
y0 – Initial condition.
args – Additional arguments passed to the vector field.
- Returns:
Diffrax solution object containing integration results.
- solve_sde(drift, diffusion, y0, rng, args=None, *, solver=None, levy_area=diffrax.BrownianIncrement, noise_transform=None)
Solve SDE using configured parameters.
Solves the Itô SDE: dy = drift(t, y, args) dt + diffusion(t, y, args) dW
- Parameters:
drift – Drift function (t, y, args) -> dy_drift.
diffusion – Diffusion function (t, y, args) -> noise_scale.
y0 – Initial condition.
rng (Array) – Random key for Brownian motion.
args – Additional arguments passed to drift and diffusion.
solver (AbstractSolver | None) – Override solver (default: Euler for SDE).
levy_area (type) – Levy area type for Brownian motion.
noise_transform (Callable | None) – Optional transform applied to Brownian increments. The SDE becomes: dy = drift dt + diffusion * noise_transform(dW)
- Returns:
Diffrax solution object containing integration results.
Note
SDE solving uses DirectAdjoint regardless of the configured adjoint, as other adjoint methods are not compatible with stochastic terms.
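To make the update rule behind dy = drift dt + diffusion * noise_transform(dW) concrete, here is a plain-Python Euler-Maruyama sketch for a scalar state. It only illustrates the equation: the function name `euler_maruyama`, the use of `random.Random` in place of a JAX random key, and the fixed-step loop are all assumptions made for this example and do not describe how `solve_sde` or diffrax is implemented.

```python
import math
import random


def euler_maruyama(drift, diffusion, y0, t_start, t_end, dt, rng,
                   args=None, noise_transform=None):
    """Fixed-step Euler-Maruyama for a scalar Ito SDE (illustrative only)."""
    n_steps = max(1, round(abs(t_end - t_start) / dt))
    h = (t_end - t_start) / n_steps
    t, y = t_start, y0
    for _ in range(n_steps):
        # Brownian increment over one step: dW ~ Normal(0, |h|).
        dw = rng.gauss(0.0, math.sqrt(abs(h)))
        if noise_transform is not None:
            dw = noise_transform(dw)
        # y_{n+1} = y_n + drift * h + diffusion * dW
        y = y + drift(t, y, args) * h + diffusion(t, y, args) * dw
        t += h
    return y


# Hypothetical Ornstein-Uhlenbeck-style example: args = (theta, sigma).
def ou_drift(t, y, args):
    return -args[0] * y


def ou_diffusion(t, y, args):
    return args[1]


params = (1.0, 0.5)

y_noisy = euler_maruyama(ou_drift, ou_diffusion, 1.0, 0.0, 1.0, 0.05,
                         random.Random(0), args=params)

# A noise_transform that zeroes the increments reduces the scheme to
# deterministic explicit Euler, which is easy to check by hand.
y_det = euler_maruyama(ou_drift, ou_diffusion, 1.0, 0.0, 1.0, 0.05,
                       random.Random(0), args=params,
                       noise_transform=lambda dw: 0.0)
print(y_noisy, y_det)
```

Zeroing the increments gives a quick sanity check: 20 explicit Euler steps of size 0.05 on dy = -y dt multiply the state by 0.95 each step, so `y_det` should match `0.95 ** 20` up to rounding.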