ODE and SDE solvers with a unified lax.scan integration loop.
Solver lineup
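ODE/SDE solvers (implement both ode_step and sde_step):
- Euler, Heun, RK4: first-, second- and fourth-order explicit Runge–Kutta. For Heun and RK4, sde_step interprets the term as a Stratonovich predictor–corrector that reuses the same Wiener increment across all stages; Euler's sde_step is the Euler–Maruyama scheme, which converges to the Itô solution.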
ODE-only solvers (raise NotImplementedError from sde_step):
- Tsit5: Tsitouras 5(4)6 explicit RK, order 5 (fixed-step).
- Dopri5: Dormand–Prince 5(4)7 explicit RK (FSAL), order 5 (fixed-step).
SDE-only solvers (raise NotImplementedError from ode_step):
- EulerHeun: diffrax-style Stratonovich predictor–corrector (diffusion-only predictor, Euler drift). Strong order 1.0.
- ItoMilstein, StratonovichMilstein: diagonal-noise Milstein schemes. Strong order 1.0. g must have shape (state_dim,); matrix diffusion raises.
State
The state y may be any PyTree (a bare array is the single-leaf special
case, e.g. [lat, lon]). The trajectory returned by solve() has the
same PyTree structure as y0 with a leading n_save + 1 axis on every
leaf. A PyTree state makes second-order dynamics natural — carry (x, v) and
let the term return (dx, dv) = (v, f(x, t)) (see solve()).
Term API
ODE term: f(t, y, args[, ctrl]) -> dy returns the time derivative dy as
a PyTree with the same structure as y (for a flat state, the velocity
[dlat/dt, dlon/dt] in degrees/second).
SDE term: f(t, y, args[, ctrl]) -> (drift, diffusion). drift is a PyTree
matching y; diffusion maps the Wiener increment to a y-shaped tangent
and may be:
- a PyTree matching the noise structure: diagonal noise, applied leafwise as g * dW (for a flat state, g.shape == (state,));
- a bare 2-D array of shape (state, n_noise): applied as g @ dW;
- a lineax.AbstractLinearOperator: applied as g.mv(dW) (general, matrix or cross-leaf noise; requires the optional lineax dependency).
The solver applies it to the increment $dW = \sqrt{dt}\,z$, with $z \sim \mathcal{N}(0, I)$
drawn internally; the term never receives z.
The Milstein solvers require a flat array state with diagonal g.
The optional ctrl argument is present when controls is passed to
solve(); the solver slices controls[i] at each step and forwards it
to the term. The term owns all interpretation and scaling of the slice.
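To make the diffusion forms concrete, here is a small NumPy sketch for the flat-array case. apply_diffusion and example_sde_term are hypothetical helpers, not pastax functions; the lineax branch is left out so the block has no optional dependency, and the ctrl handling only illustrates that the term, not the solver, decides what a control slice means.

```python
import numpy as np


def apply_diffusion(g, dW):
    """Map a Wiener increment to a state-shaped tangent (flat-array state only).

    1-D g: diagonal noise, applied elementwise as g * dW.
    2-D g: matrix noise of shape (state, n_noise), applied as g @ dW.
    The lineax operator form (g.mv(dW)) is not covered by this sketch.
    """
    g = np.asarray(g)
    if g.ndim == 1:
        return g * dW
    if g.ndim == 2:
        return g @ dW
    raise ValueError(f"unsupported diffusion shape {g.shape}")


def example_sde_term(t, y, args, ctrl=None):
    """Hypothetical SDE term: drift in degrees/second plus 3-source matrix noise."""
    drift = np.array([1e-5, 2e-5])
    if ctrl is not None:
        drift = drift + ctrl  # the term owns the interpretation of ctrl
    diffusion = np.array([[1e-3, 0.0, 5e-4],
                          [0.0, 1e-3, 5e-4]])  # shape (state=2, n_noise=3)
    return drift, diffusion


dt = 60.0
z = np.array([0.5, -1.0, 2.0])  # drawn by the solver, never seen by the term
drift, g = example_sde_term(0.0, np.zeros(2), None, ctrl=np.array([0.0, 1e-5]))
dW = np.sqrt(dt) * z
increment = drift * dt + apply_diffusion(g, dW)
```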
Noise convention
The per-step Wiener increment is $dW = \sqrt{dt}\,z$, with z a standard
normal drawn internally; the SDE term never sees z. The noise is sampled to
match y0’s structure by default, or an explicit brownian_structure
(a PyTree of jax.ShapeDtypeStruct) when the noise space differs from the
state space (e.g. driving an n-dim state with an m-dim Brownian motion
through a lineax operator). Passing key to solve() activates SDE
mode. A single trajectory is produced by default; pass n_samples > 1 for an
ensemble of independent realisations (an extra leading n_samples axis via
internal vmap over split keys).
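The sampling rule can be sketched in NumPy as follows. sample_wiener_increments is a hypothetical helper: NumPy's SeedSequence.spawn stands in for splitting a JAX key, and a tuple of leaf shapes stands in for a PyTree of jax.ShapeDtypeStruct. The sketch takes the magnitude of dt so that a negative step still gives a real scale factor.

```python
import numpy as np


def sample_wiener_increments(seed, structure, n_steps, dt, n_samples=1):
    """Draw dW = sqrt(|dt|) * z for every leaf of a noise structure.

    `structure` is a tuple of leaf shapes. Spawning one child seed per sample
    plays the role of splitting a PRNG key, so realisations are independent.
    """
    children = np.random.SeedSequence(seed).spawn(n_samples)
    per_sample = []
    for child in children:
        rng = np.random.default_rng(child)
        per_sample.append(tuple(
            np.sqrt(abs(dt)) * rng.standard_normal((n_steps, *shape))
            for shape in structure
        ))
    # Stack leaf by leaf: every leaf gains a leading n_samples axis.
    return tuple(np.stack(leaves) for leaves in zip(*per_sample))


# A 2-D state driven by a 3-D Brownian motion: the noise structure differs
# from the state structure, which is the brownian_structure use case.
(dW,) = sample_wiener_increments(seed=0, structure=((3,),), n_steps=1000, dt=0.5, n_samples=4)
```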
Backwards-in-time integration is supported for all solvers: pass a negative
int_dt (and a matching negative save_dt) to solve(). SDE backwards
integration is not a textbook construction, but it stays finite because the
solver normalises the $\sqrt{dt}$ factor by sign and magnitude, so it never takes the square root of a negative step.
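A quick way to see backwards integration at work is to run a deterministic problem forwards and then backwards with the negated step, which should land close to the starting state. The sketch below does this with a hand-written Heun loop in NumPy (heun_integrate and decay are illustrative names, not pastax API) and also shows the magnitude-based noise factor that keeps the square root finite for dt < 0.

```python
import numpy as np


def heun_integrate(f, y0, t0, dt, n_steps):
    """Fixed-step Heun ODE integration; a negative dt runs backwards in time."""
    t, y = t0, np.asarray(y0, dtype=float)
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + dt, y + dt * k1)
        y = y + 0.5 * dt * (k1 + k2)
        t = t + dt
    return t, y


def decay(t, y):
    return -y


# Forward from t = 0 to t = 1, then backwards with the negated step.
t1, y1 = heun_integrate(decay, [1.0, 2.0], 0.0, 0.01, 100)
t0, y0_back = heun_integrate(decay, y1, t1, -0.01, 100)

# np.sqrt(-0.01) would be nan; using the magnitude of the step keeps it finite.
noise_scale = np.sqrt(abs(-0.01))
```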
AbstractSolver
- class pastax.solver.AbstractSolver
Bases: Module
Abstract base class for fixed-step ODE/SDE solvers.
Subclasses implement ode_step() (deterministic) and sde_step() (stochastic, with a pre-sampled z). Solvers that are specific to one mode raise NotImplementedError from the other.
- abstractmethod ode_step(term, t, y, dt, args, ctrl)
Advance the ODE state by one step of size dt.
- Parameters
  - term (Callable) – Drift callable f(t, y, args[, ctrl]) -> Float[Array, "2"] returning velocity in degrees per second.
  - t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.
  - y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.
  - dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.
  - args (PyTree) – Arbitrary fixed PyTree forwarded to term.
  - ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.
- Returns
Updated state [lat, lon] in degrees after one step.
- Return type
Float[jaxlib._jax.Array, '2']
- abstractmethod sde_step(term, t, y, dt, args, ctrl, z)
Advance the SDE state by one step using a pre-sampled z.
- Parameters
  - term (Callable) – Stochastic dynamics callable f(t, y, args[, ctrl]) -> (drift, g). drift is the deterministic velocity in degrees per second; g is the diffusion coefficient with shape (2,) (diagonal) or (2, 2) (full matrix). The term never sees z; the solver applies the Wiener increment.
  - t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.
  - y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.
  - dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.
  - args (PyTree) – Arbitrary fixed PyTree forwarded to term.
  - ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.
  - z (Float[jaxlib._jax.Array, 'n_noise']) – Standard-normal noise sample of shape (n_noise,). The solver uses the Wiener increment $dW = \sqrt{dt}\,z$.
- Returns
Updated state [lat, lon] in degrees after one step.
- Return type
Float[jaxlib._jax.Array, '2']
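The two-mode contract can be mimicked with Python's abc module, as a rough sketch; the documented class derives from Module and is not reproduced here. ODEOnlyEuler is a hypothetical subclass that, like Tsit5 and Dopri5, raises NotImplementedError from the mode it does not support.

```python
import abc


class AbstractSolver(abc.ABC):
    """Plain-abc sketch of the fixed-step solver interface."""

    @abc.abstractmethod
    def ode_step(self, term, t, y, dt, args, ctrl):
        """Advance the deterministic state by one step of size dt."""

    @abc.abstractmethod
    def sde_step(self, term, t, y, dt, args, ctrl, z):
        """Advance the stochastic state by one step using a pre-sampled z."""


class ODEOnlyEuler(AbstractSolver):
    """Hypothetical ODE-only solver: sde_step raises, like Tsit5 and Dopri5."""

    def ode_step(self, term, t, y, dt, args, ctrl):
        # This sketch always forwards ctrl, so the term must accept it.
        return y + term(t, y, args, ctrl) * dt

    def sde_step(self, term, t, y, dt, args, ctrl, z):
        raise NotImplementedError("ODEOnlyEuler does not support SDE mode")
```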
Euler
- class pastax.solver.Euler
Bases: AbstractSolver
Explicit Euler / Euler–Maruyama solver (first-order, fixed-step).
- ode_step(term, t, y, dt, args, ctrl)
One Euler step: y_new = y + term(t, y, args) * dt.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
- Return type
Float[jaxlib._jax.Array, '2']
- sde_step(term, t, y, dt, args, ctrl, z)
One Euler–Maruyama step: $y_{\text{new}} = y + \text{drift}\,dt + g\,dW$, where (drift, g) = term(t, y, args) and $dW = \sqrt{dt}\,z$.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
  - z (Float[jaxlib._jax.Array, 'n_noise'])
- Return type
Float[jaxlib._jax.Array, '2']
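Both Euler updates are one-liners; the NumPy sketch below spells them out for diagonal diffusion. euler_ode_step and euler_maruyama_step are illustrative names, and the matrix-noise case would use g @ dW in place of g * dW.

```python
import numpy as np


def euler_ode_step(term, t, y, dt, args):
    """Explicit Euler: y_new = y + term(t, y, args) * dt."""
    return y + term(t, y, args) * dt


def euler_maruyama_step(term, t, y, dt, args, z):
    """Euler–Maruyama with diagonal diffusion: y + drift*dt + g*dW, dW = sqrt(dt)*z."""
    drift, g = term(t, y, args)
    dW = np.sqrt(dt) * z
    return y + drift * dt + g * dW


# Constant drift and diffusion, so the step is easy to check by hand.
const_term = lambda t, y, args: (np.array([1.0, -1.0]), np.array([2.0, 0.5]))
y1 = euler_maruyama_step(const_term, 0.0, np.zeros(2), 0.25, None, np.array([1.0, -2.0]))
```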
Heun
- class pastax.solver.Heun
Bases: AbstractSolver
Heun (explicit second-order, two-stage Runge–Kutta) solver.
Convergence order 2 in the ODE case. The SDE step is a Stratonovich predictor–corrector that reuses the same dW in both stages.
- ode_step(term, t, y, dt, args, ctrl)
One Heun (trapezoidal) step.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
- Return type
Float[jaxlib._jax.Array, '2']
- sde_step(term, t, y, dt, args, ctrl, z)
One Stratonovich Heun step (same dW in predictor and corrector).
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
  - z (Float[jaxlib._jax.Array, 'n_noise'])
- Return type
Float[jaxlib._jax.Array, '2']
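A sketch of the two Heun updates in NumPy, assuming diagonal diffusion. The SDE variant shown is the common Stratonovich Heun form in which the corrector reuses the predictor's dW; it illustrates the scheme the docstring names and is not a copy of pastax's code. heun_ode_step and heun_sde_step are illustrative names.

```python
import numpy as np


def heun_ode_step(term, t, y, dt, args):
    """Trapezoidal (Heun) step: average the slopes at both ends of the step."""
    k1 = term(t, y, args)
    k2 = term(t + dt, y + dt * k1, args)
    return y + 0.5 * dt * (k1 + k2)


def heun_sde_step(term, t, y, dt, args, z):
    """Stratonovich Heun step with diagonal g; both stages use the same dW."""
    dW = np.sqrt(dt) * z
    f1, g1 = term(t, y, args)
    y_pred = y + f1 * dt + g1 * dW  # Euler–Maruyama predictor
    f2, g2 = term(t + dt, y_pred, args)
    return y + 0.5 * (f1 + f2) * dt + 0.5 * (g1 + g2) * dW


# Stratonovich dy = sigma * y o dW has the exact solution y0 * exp(sigma * W).
sigma = 1.0
gbm_strat = lambda t, y, args: (np.zeros_like(y), sigma * y)
y_heun = heun_sde_step(gbm_strat, 0.0, np.array([1.0]), 1e-4, None, np.array([1.0]))
```

With dt = 1e-4 and z = 1 the increment is dW = 0.01, and the Heun result agrees with exp(0.01) up to a third-order term in dW.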
RK4
- class pastax.solver.RK4
Bases: AbstractSolver
Classical fourth-order Runge–Kutta solver (four stages, fixed-step).
Convergence order 4 in the ODE case. The SDE step reuses the same dW across all four stages, giving a Stratonovich-consistent scheme whose strong order depends on the noise structure and does not reach the ODE order of 4.
- ode_step(term, t, y, dt, args, ctrl)
One classical RK4 step.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
- Return type
Float[jaxlib._jax.Array, '2']
- sde_step(term, t, y, dt, args, ctrl, z)
One stochastic RK4 step (Stratonovich, single dW across stages).
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
  - z (Float[jaxlib._jax.Array, 'n_noise'])
- Return type
Float[jaxlib._jax.Array, '2']
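The classical RK4 update and a quick empirical order check on dy/dt = -y can be sketched in NumPy; rk4_step, integrate and decay are illustrative helpers. Halving the step should shrink the final-time error by roughly 2^4 = 16.

```python
import numpy as np


def rk4_step(term, t, y, dt, args):
    """Classical four-stage Runge–Kutta step."""
    k1 = term(t, y, args)
    k2 = term(t + dt / 2, y + dt / 2 * k1, args)
    k3 = term(t + dt / 2, y + dt / 2 * k2, args)
    k4 = term(t + dt, y + dt * k3, args)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def integrate(step, term, y0, t0, t1, n_steps, args=None):
    """Apply a fixed-step solver n_steps times between t0 and t1."""
    dt = (t1 - t0) / n_steps
    t, y = t0, np.asarray(y0, dtype=float)
    for _ in range(n_steps):
        y = step(term, t, y, dt, args)
        t += dt
    return y


def decay(t, y, args):
    return -y


errors = [abs(integrate(rk4_step, decay, [1.0], 0.0, 1.0, n)[0] - np.exp(-1.0)) for n in (10, 20)]
observed_order = np.log2(errors[0] / errors[1])
```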
Tsit5
- class pastax.solver.Tsit5
Bases: AbstractSolver
Tsitouras 5(4)6 explicit Runge–Kutta (ODE-only, fixed-step, order 5).
Six stages and no embedded error estimator: the 4th-order companion row of Tsitouras (2011) is unused because the step size is fixed.
- ode_step(term, t, y, dt, args, ctrl)
One Tsit5 step (5th-order weights only).
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
- Return type
Float[jaxlib._jax.Array, '2']
- sde_step(term, t, y, dt, args, ctrl, z)
Not supported: Tsit5 is ODE-only, so this method raises NotImplementedError.
Dopri5
- class pastax.solver.Dopri5
Bases: AbstractSolver
Dormand–Prince 5(4)7 explicit Runge–Kutta (ODE-only, fixed-step, order 5).
Seven stages with the first-same-as-last (FSAL) property; only the 5th-order row is used here.
- ode_step(term, t, y, dt, args, ctrl)
One Dopri5 step (5th-order weights only).
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
- Return type
Float[jaxlib._jax.Array, '2']
- sde_step(term, t, y, dt, args, ctrl, z)
Not supported: Dopri5 is ODE-only, so this method raises NotImplementedError.
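For reference, a fixed-step Dormand–Prince update that uses only the 5th-order weights can be sketched in NumPy with the standard published Butcher coefficients. dopri5_step and final_error are illustrative names, not pastax code. Because the seventh (FSAL) stage has zero weight in the 5th-order solution, six stage evaluations are enough when no error estimate is needed.

```python
import numpy as np

# Dormand–Prince 5(4) tableau: nodes, lower-triangular stage weights, 5th-order row.
C = np.array([0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0])
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
]
B5 = np.array([35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84])


def dopri5_step(term, t, y, dt, args):
    """One fixed-step Dormand–Prince step using the 5th-order weights only."""
    k = []
    for c_i, a_i in zip(C, A):
        # Stage state: y plus the weighted sum of the earlier stage slopes.
        y_stage = y + dt * sum(a_ij * k_j for a_ij, k_j in zip(a_i, k))
        k.append(term(t + c_i * dt, y_stage, args))
    return y + dt * sum(b_i * k_i for b_i, k_i in zip(B5, k))


def final_error(n_steps):
    """Error at t = 1 for dy/dt = -y, y(0) = 1, using n_steps fixed steps."""
    dt, y = 1.0 / n_steps, np.array([1.0])
    for i in range(n_steps):
        y = dopri5_step(lambda t, y, args: -y, i * dt, y, dt, None)
    return abs(y[0] - np.exp(-1.0))


observed_order = np.log2(final_error(10) / final_error(20))
```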
EulerHeun
- class pastax.solver.EulerHeun
Bases: AbstractSolver
Stochastic Euler–Heun solver (SDE-only, Stratonovich, strong order 1.0).
Matches diffrax's EulerHeun algorithm: the predictor uses the diffusion only and the drift is applied once, Euler-style. Accepts both diagonal (g.shape == (2,)) and full (g.shape == (2, 2)) diffusion shapes.
- ode_step(term, t, y, dt, args, ctrl)
Not supported: EulerHeun is SDE-only, so this method raises NotImplementedError.
- sde_step(term, t, y, dt, args, ctrl, z)
One stochastic Euler–Heun step.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
  - z (Float[jaxlib._jax.Array, 'n_noise'])
- Return type
Float[jaxlib._jax.Array, '2']
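Following the diffrax Euler–Heun algorithm as we understand it (predictor from the diffusion alone, drift applied once), a diagonal-noise sketch in NumPy looks like this. euler_heun_step is an illustrative name rather than pastax code. With additive (state-independent) noise the scheme reduces to Euler–Maruyama, which the second example checks.

```python
import numpy as np


def euler_heun_step(term, t, y, dt, args, z):
    """Stochastic Euler–Heun step (Stratonovich) with diagonal diffusion."""
    dW = np.sqrt(dt) * z
    drift, g0 = term(t, y, args)
    g0_dW = g0 * dW
    y_pred = y + g0_dW  # predictor: diffusion only, no drift
    _, g1 = term(t, y_pred, args)  # diffusion at the predicted state, same t
    return y + drift * dt + 0.5 * (g0_dW + g1 * dW)


# Stratonovich dy = sigma * y o dW: one step should track y0 * exp(sigma * dW).
sigma = 1.0
gbm_strat = lambda t, y, args: (np.zeros_like(y), sigma * y)
y_eh = euler_heun_step(gbm_strat, 0.0, np.array([1.0]), 1e-4, None, np.array([1.0]))

# Additive noise: the result equals the Euler–Maruyama update.
additive = lambda t, y, args: (np.array([0.3]), np.array([0.2]))
y_add = euler_heun_step(additive, 0.0, np.array([1.0]), 0.04, None, np.array([-1.5]))
```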
ItoMilstein
- class pastax.solver.ItoMilstein
Bases: AbstractSolver
Itô Milstein solver (SDE-only, diagonal noise, strong order 1.0).
Requires g.shape == (2,). Raises NotImplementedError for matrix-valued g; use EulerHeun for general noise.
- ode_step(term, t, y, dt, args, ctrl)
Not supported: ItoMilstein is SDE-only, so this method raises NotImplementedError.
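- sde_step(term, t, y, dt, args, ctrl, z)
One Itô Milstein step, applied componentwise: $y_{\text{new}} = y + \text{drift}\,dt + g\,dW + \tfrac{1}{2}\,g\,g'\,(dW^2 - dt)$, where $dW = \sqrt{dt}\,z$ and $g' = \partial g_i / \partial y_i$.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
  - z (Float[jaxlib._jax.Array, 'n_noise'])
- Return type
Float[jaxlib._jax.Array, '2']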
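A diagonal-noise Itô Milstein step, plus a small strong-error comparison against Euler–Maruyama on geometric Brownian motion (whose exact solution is known), can be sketched in NumPy. The derivative callable passed as dg_dy is a hypothetical stand-in for however pastax obtains the derivative of each g_i with respect to y_i; in JAX it would normally come from automatic differentiation.

```python
import numpy as np


def ito_milstein_step(term, dg_dy, t, y, dt, args, z):
    """Diagonal-noise Itô Milstein step, applied componentwise."""
    dW = np.sqrt(dt) * z
    drift, g = term(t, y, args)
    correction = 0.5 * g * dg_dy(t, y, args) * (dW**2 - dt)
    return y + drift * dt + g * dW + correction


def euler_maruyama_step(term, t, y, dt, args, z):
    dW = np.sqrt(dt) * z
    drift, g = term(t, y, args)
    return y + drift * dt + g * dW


# Itô GBM: dy = MU*y dt + SIGMA*y dW, exact y_T = exp((MU - SIGMA^2/2) T + SIGMA W_T).
MU, SIGMA = 0.5, 1.0
gbm = lambda t, y, args: (MU * y, SIGMA * y)
gbm_dg = lambda t, y, args: np.full_like(y, SIGMA)


def strong_error(step, *extra, n_paths=2000, n_steps=100, T=1.0, seed=0):
    """Mean |numerical - exact| at T, using the same Brownian path for both."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    z = rng.standard_normal((n_steps, n_paths))
    y = np.ones(n_paths)
    for i in range(n_steps):
        y = step(gbm, *extra, i * dt, y, dt, None, z[i])
    W = np.sqrt(dt) * z.sum(axis=0)
    exact = np.exp((MU - 0.5 * SIGMA**2) * T + SIGMA * W)
    return np.mean(np.abs(y - exact))


err_milstein = strong_error(ito_milstein_step, gbm_dg)
err_em = strong_error(euler_maruyama_step)
```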
StratonovichMilstein
- class pastax.solver.StratonovichMilstein
Bases: AbstractSolver
Stratonovich Milstein solver (SDE-only, diagonal noise, strong order 1.0).
Requires g.shape == (2,). Differs from ItoMilstein only in the Milstein correction: it drops the $-\tfrac{1}{2}\,g\,(\partial g_i/\partial y_i)\,dt$ part, which is the Itô-to-Stratonovich drift correction.
- ode_step(term, t, y, dt, args, ctrl)
Not supported: StratonovichMilstein is SDE-only, so this method raises NotImplementedError.
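- sde_step(term, t, y, dt, args, ctrl, z)
One Stratonovich Milstein step, applied componentwise: $y_{\text{new}} = y + \text{drift}\,dt + g\,dW + \tfrac{1}{2}\,g\,g'\,dW^2$, where $dW = \sqrt{dt}\,z$ and $g' = \partial g_i / \partial y_i$.
- Parameters
  - term (Callable)
  - t (Float[jaxlib._jax.Array, ''])
  - y (Float[jaxlib._jax.Array, '2'])
  - dt (Float[jaxlib._jax.Array, ''])
  - args (PyTree)
  - ctrl (PyTree)
  - z (Float[jaxlib._jax.Array, 'n_noise'])
- Return type
Float[jaxlib._jax.Array, '2']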
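The relationship between the two Milstein variants can be checked numerically. For the Stratonovich SDE dy = σ y ∘ dW, the Stratonovich Milstein step should coincide with the Itô Milstein step applied to the equivalent Itô SDE, whose drift carries the extra +½ g g' term. A NumPy sketch with illustrative names (dg_dy is a hypothetical derivative helper):

```python
import numpy as np


def stratonovich_milstein_step(term, dg_dy, t, y, dt, args, z):
    """Diagonal-noise Stratonovich Milstein step: no -dt inside the correction."""
    dW = np.sqrt(dt) * z
    drift, g = term(t, y, args)
    return y + drift * dt + g * dW + 0.5 * g * dg_dy(t, y, args) * dW**2


def ito_milstein_step(term, dg_dy, t, y, dt, args, z):
    dW = np.sqrt(dt) * z
    drift, g = term(t, y, args)
    return y + drift * dt + g * dW + 0.5 * g * dg_dy(t, y, args) * (dW**2 - dt)


SIGMA = 0.8
strat_term = lambda t, y, args: (np.zeros_like(y), SIGMA * y)  # dy = SIGMA*y o dW
dg_dy = lambda t, y, args: np.full_like(y, SIGMA)
# Equivalent Itô form: drift gains +0.5 * g * g' = 0.5 * SIGMA**2 * y.
ito_equivalent = lambda t, y, args: (0.5 * SIGMA**2 * y, SIGMA * y)

y0, dt, z = np.array([1.0, 2.0]), 1e-4, np.array([1.0, -0.5])
strat = stratonovich_milstein_step(strat_term, dg_dy, 0.0, y0, dt, None, z)
ito = ito_milstein_step(strat_term, dg_dy, 0.0, y0, dt, None, z)
ito_of_equivalent = ito_milstein_step(ito_equivalent, dg_dy, 0.0, y0, dt, None, z)
```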
solve
- pastax.solver.solve(term, y0, t0, n_save, int_dt, save_dt, solver=None, args=None, controls=None, key=None, n_samples=1, brownian_structure=None, adjoint='checkpointed', checkpoints=None)
Integrate a trajectory for n_save output intervals starting at t0.
ODE mode (default, no key): term(t, y[, args, ctrl]) returns dy, the time derivative as a PyTree matching y. SDE mode (pass key): term(t, y[, args, ctrl]) returns (drift, diffusion); the solver draws a standard-normal z and applies the increment $dW = \sqrt{dt}\,z$ internally. The optional ctrl argument is present when controls is provided; the solver slices it at each step and the term owns its interpretation.
The state y may be any PyTree (a bare array is the single-leaf case). For second-order dynamics dv = f(x, t) dt + noise, dx = v dt, carry y = (x, v) (e.g. a NamedTuple) and return (v, f(x, t)) from the term; put the noise on the velocity leaf only.
The solver runs on a fine integration grid of n_fine = n_save * n_substeps steps (where n_substeps = round(save_dt / int_dt)), then keeps every n_substeps-th step to produce the n_save + 1 saved states.
Ensemble: pass n_samples > 1 in SDE mode; the key is split internally. Perturbed ODE: use ODE mode with controls and jax.vmap(lambda c: solve(..., controls=c))(controls_batch).
- Parameters
  - term (Callable) – Dynamics callable f(t, y[, args, ctrl]). ODE: returns dy (PyTree matching y). SDE: returns (drift, diffusion), where diffusion is a PyTree matching the noise (diagonal), a 2-D array (state, n_noise) (matrix), or a lineax linear operator.
  - y0 (Float[jaxlib._jax.Array, '2']) – Initial state. Any PyTree; a bare array (e.g. [lat, lon], shape (2,)) is the single-leaf case. Defines the output structure.
  - t0 (Float[jaxlib._jax.Array, '']) – Start time in seconds. A JAX scalar, so it can change between calls without recompilation. The implicit end time is t0 + n_save * save_dt.
  - n_save (int) – Number of output intervals (static). Each output leaf has a leading n_save + 1 axis including the initial state.
  - int_dt (float) – Integration step size in seconds (static). Use a negative value for backward-in-time integration.
  - save_dt (float) – Output interval in seconds (static). Must be an integer multiple of int_dt with the same sign, so that n_substeps = round(save_dt / int_dt) >= 1.
  - solver (AbstractSolver | None) – Solver instance. Defaults to Heun().
  - args (PyTree | None) – Arbitrary fixed PyTree passed through to term (e.g. a Dataset).
  - controls (PyTree | None) – Arbitrary per-step PyTree with leading axis n_fine. Sliced at each integration step.
  - key (Key[jaxlib._jax.Array, ''] | None) – PRNG key for SDE mode. When provided, solve() draws a standard-normal noise sequence (shaped per brownian_structure) and runs in SDE mode.
  - n_samples (int) – Number of independent SDE realisations (default 1). Ignored in ODE mode. When > 1, the key is split and trajectories are vmapped, adding a leading n_samples axis to every output leaf.
  - brownian_structure (PyTree | None) – Optional prototype PyTree of jax.ShapeDtypeStruct (or arrays) describing the Wiener process, used when the noise space differs from the state space (e.g. an m-dim Brownian motion driving an n-dim state via a lineax operator). Defaults to y0's structure and per-leaf shapes.
  - adjoint (str) – Differentiation strategy for the integration loop. "checkpointed" (default) uses binomial checkpointing (treeverse) for low reverse-mode memory, but is reverse-mode only: jax.jvp is not supported. "forward" uses a plain jax.lax.scan (no per-step checkpointing), which suits forward-mode AD (jax.jvp / jax.jacfwd) and mirrors diffrax.ForwardMode. Reverse mode also works through this path, at full O(n_fine) activation memory.
  - checkpoints (int | str | None) – Memory knob for adjoint="checkpointed" (ignored otherwise). None (default) forwards to equinox.internal.scan's built-in O(sqrt(n_fine)) Stumm–Walther online schedule, the same default that diffrax.RecursiveCheckpointAdjoint uses, balancing memory against backward recompute. A smaller int (e.g. ceil(log2(n_fine))) saves memory at the cost of more recompute; a larger int (or "all" for one per step) trades memory for less recompute.
- Returns
A PyTree with the structure of y0; each leaf gains a leading n_save + 1 axis (and an extra leading n_samples axis in SDE mode with n_samples > 1). For a flat (2,) state this is shape (n_save + 1, 2) or (n_samples, n_save + 1, 2).
- Return type
Array
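The fine-grid bookkeeping described for solve() can be sketched with a plain Python loop standing in for the scanned loop. save_grid and solve_sketch are hypothetical names; the sketch covers only ODE mode with a flat array state and leaves out controls, noise, n_samples and checkpointing.

```python
import numpy as np


def save_grid(n_save, int_dt, save_dt):
    """Return (n_substeps, n_fine) for a given output grid."""
    ratio = save_dt / int_dt
    n_substeps = round(ratio)
    if n_substeps < 1 or not np.isclose(ratio, n_substeps):
        raise ValueError("save_dt must be an integer multiple of int_dt with the same sign")
    return n_substeps, n_save * n_substeps


def solve_sketch(step, term, y0, t0, n_save, int_dt, save_dt, args=None):
    """Take n_fine fine steps and keep every n_substeps-th state plus the initial one."""
    n_substeps, n_fine = save_grid(n_save, int_dt, save_dt)
    y = np.asarray(y0, dtype=float)
    saved = [y]
    for i in range(n_fine):
        y = step(term, t0 + i * int_dt, y, int_dt, args)
        if (i + 1) % n_substeps == 0:
            saved.append(y)
    return np.stack(saved)  # shape (n_save + 1, *y0.shape)


euler = lambda term, t, y, dt, args: y + term(t, y, args) * dt
drift = lambda t, y, args: np.array([1e-5, 2e-5])  # degrees per second
traj = solve_sketch(euler, drift, [10.0, 20.0], 0.0, n_save=4, int_dt=60.0, save_dt=3600.0)
back = solve_sketch(euler, drift, [10.0, 20.0], 0.0, n_save=4, int_dt=-60.0, save_dt=-3600.0)
```

Here four hourly outputs with one-minute steps give n_substeps = 60 and n_fine = 240; the negative-step call walks the same grid backwards in time.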