
pastax.solver

ODE and SDE solvers with a unified lax.scan integration loop.

Solver lineup

ODE/SDE solvers (implement both ode_step and sde_step):

  • Euler, Heun, RK4: first-, second-, and fourth-order explicit Runge–Kutta. In sde_step, Euler takes an Euler–Maruyama (Itô) step, while Heun and RK4 treat the term as a Stratonovich predictor–corrector that reuses the same Wiener increment across stages.

ODE-only solvers (raise on sde_step):

  • Tsit5 — Tsitouras 5(4)6 explicit RK, order 5 (fixed-step).

  • Dopri5 — Dormand–Prince 5(4)7 explicit RK (FSAL), order 5 (fixed-step).

SDE-only solvers (raise on ode_step):

  • EulerHeun — diffrax-style Stratonovich predictor–corrector (diffusion-only predictor, Euler drift). Strong order 1.0.

  • ItoMilstein, StratonovichMilstein — diagonal-noise Milstein schemes. Strong order 1.0. g must have shape (state_dim,); matrix diffusion raises.

State

The state y may be any PyTree (a bare array is the single-leaf special case, e.g. [lat, lon]). The trajectory returned by solve() has the same PyTree structure as y0 with a leading n_save + 1 axis on every leaf. A PyTree state makes second-order dynamics natural — carry (x, v) and let the term return (dx, dv) = (v, f(x, t)) (see solve()).
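A minimal sketch of why a PyTree state needs no special casing: one leafwise update written with jax.tree_util.tree_map works for a bare [lat, lon] array and for an (x, v) pair alike. The helper euler_update and the oscillator term are hypothetical illustrations, not part of pastax.

```python
import jax
import jax.numpy as jnp


def euler_update(y, dy, dt):
    """Hypothetical helper: y + dy * dt applied leafwise to any PyTree."""
    return jax.tree_util.tree_map(lambda leaf, d: leaf + d * dt, y, dy)


# Bare-array state: [lat, lon] in degrees, velocity in degrees/second.
y_flat = jnp.array([45.0, -30.0])
dy_flat = jnp.array([1e-4, -2e-4])
print(euler_update(y_flat, dy_flat, 60.0))


# Second-order state (x, v): the term returns (dx, dv) = (v, f(x)).
def oscillator(t, y, args):
    x, v = y
    return (v, -x)  # f(x) = -x, a unit-frequency harmonic oscillator


y_pair = (jnp.array(1.0), jnp.array(0.0))
y_next = euler_update(y_pair, oscillator(0.0, y_pair, None), 0.1)
print(y_next)  # same (x, v) structure as y_pair
```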

Term API

ODE term: f(t, y, args[, ctrl]) -> dy, where dy is the time derivative as a PyTree with the same structure as y (for a flat state, the velocity [dlat/dt, dlon/dt] in degrees/second).

SDE term: f(t, y, args[, ctrl]) -> (drift, diffusion). drift is a PyTree matching y; diffusion maps the Wiener increment to a y-shaped tangent and may be:

  • a PyTree matching the noise structure — diagonal noise, applied leafwise as g * dW (for a flat state, g.shape == (state,));

  • a bare 2-D array (state, n_noise) — applied as g @ dW;

  • a lineax.AbstractLinearOperator — applied as g.mv(dW) (general / matrix / cross-leaf noise; requires the optional lineax dependency).

The solver applies it as $dy = \mathrm{drift}\,dt + g\,dW$ with $dW = \sqrt{|dt|}\,z$ drawn internally; the term never receives z. The Milstein solvers require a flat array state with diagonal g.
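To make the diffusion forms concrete, here is a sketch of how g could act on dW for the first two cases (the lineax case is left out because it needs the optional dependency). It assumes the dispatch rule stated above, that a bare 2-D array means matrix noise and anything else is applied leafwise; apply_diffusion is a hypothetical name, not the pastax internals.

```python
import jax
import jax.numpy as jnp


def apply_diffusion(g, dW):
    """Hypothetical sketch of how a diffusion term could act on a Wiener increment.

    - bare 2-D array (state, n_noise): matrix noise, g @ dW
    - anything else: diagonal noise, applied leafwise as g * dW
    """
    if isinstance(g, jax.Array) and g.ndim == 2:
        return g @ dW
    return jax.tree_util.tree_map(lambda gi, dwi: gi * dwi, g, dW)


dt = 3600.0  # one hour, in seconds
z = jax.random.normal(jax.random.PRNGKey(0), (2,))
dW = jnp.sqrt(jnp.abs(dt)) * z  # built by the solver; the term never sees z

g_diag = jnp.array([1e-3, 2e-3])                # shape (state,)
g_mat = jnp.array([[1e-3, 0.0], [5e-4, 2e-3]])  # shape (state, n_noise)
print(apply_diffusion(g_diag, dW), apply_diffusion(g_mat, dW))
```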

The optional ctrl argument is present when controls is passed to solve(); the solver slices controls[i] at each step and forwards it to the term. The term owns all interpretation and scaling of the slice.
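A sketch of the controls contract, assuming an array-valued controls with leading axis n_fine and a simple Euler update: jax.lax.scan walks the step indices, slices controls[i], and hands it to the term untouched, so the term alone decides what the slice means (here, a hypothetical additive velocity in degrees/second). drift_with_ctrl and run_with_controls are illustrative names.

```python
import jax
import jax.numpy as jnp


def drift_with_ctrl(t, y, args, ctrl):
    # The term owns the interpretation: here ctrl is an additive velocity (deg/s).
    return args["velocity"] + ctrl


def run_with_controls(term, y0, t0, dt, args, controls):
    """Hypothetical Euler loop that forwards controls[i] to the term at step i."""

    def step(y, inputs):
        i, ctrl = inputs
        t = t0 + i * dt
        y_new = y + term(t, y, args, ctrl) * dt
        return y_new, y_new

    n_fine = controls.shape[0]
    _, ys = jax.lax.scan(step, y0, (jnp.arange(n_fine), controls))
    return ys


controls = jnp.zeros((4, 2)).at[2].set(jnp.array([1e-4, 0.0]))  # one kick at step 2
ys = run_with_controls(drift_with_ctrl, jnp.array([45.0, -30.0]), 0.0, 10.0,
                       {"velocity": jnp.array([0.0, 1e-4])}, controls)
```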

Noise convention

Per-step Wiener increment is $dW = \sqrt{|dt|}\,z$ with z a standard normal drawn internally; the SDE term never sees z. The noise is sampled to match y0’s structure by default, or an explicit brownian_structure (a PyTree of jax.ShapeDtypeStruct) when the noise space differs from the state space (e.g. driving an n-dim state with an m-dim Brownian motion through a lineax operator). Passing key to solve() activates SDE mode. A single trajectory is produced by default; pass n_samples > 1 for an ensemble of independent realisations (an extra leading n_samples axis via internal vmap over split keys).
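The structure-matching rule can be sketched as drawing one independent standard-normal leaf per leaf of a prototype PyTree (y0, or a brownian_structure of jax.ShapeDtypeStruct), then forming dW leafwise. An ensemble then amounts to vmapping over split keys. The helper sample_z is hypothetical and only mirrors the description above.

```python
import jax
import jax.numpy as jnp


def sample_z(key, structure):
    """Hypothetical: draw z ~ N(0, I) with the PyTree structure and leaf shapes of `structure`."""
    leaves, treedef = jax.tree_util.tree_flatten(structure)
    keys = jax.random.split(key, len(leaves))  # one independent key per leaf
    z_leaves = [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(keys, leaves)]
    return jax.tree_util.tree_unflatten(treedef, z_leaves)


# Noise space differs from state space: e.g. a 2-dim Brownian motion.
brownian_structure = jax.ShapeDtypeStruct((2,), jnp.float32)
dt = 600.0
z = sample_z(jax.random.PRNGKey(0), brownian_structure)
dW = jax.tree_util.tree_map(lambda zi: jnp.sqrt(jnp.abs(dt)) * zi, z)

# Ensemble: independent draws via vmap over split keys (extra leading n_samples axis).
n_samples = 5
keys = jax.random.split(jax.random.PRNGKey(1), n_samples)
z_ens = jax.vmap(lambda k: sample_z(k, brownian_structure))(keys)
```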

Backward-in-time integration is supported for all solvers: pass a negative int_dt (and a matching negative save_dt) to solve(). Backward SDE integration is not a textbook construction, but the Wiener increment stays real and finite because the solver scales z by $\sqrt{|dt|}$ rather than $\sqrt{dt}$.
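A short sketch of both halves of that claim. The naive factor sqrt(dt) is NaN for a negative step, whereas sqrt(|dt|) stays real (and, since z is symmetric, the distribution of dW is unchanged). For the ODE part, a hypothetical Heun loop integrated forward to t = 1 and then backward with a negative step recovers the initial condition of dy/dt = -y up to truncation and float32 round-off.

```python
import jax
import jax.numpy as jnp

dt_back = -0.01
print(jnp.sqrt(dt_back))           # NaN: the naive factor breaks for dt < 0
print(jnp.sqrt(jnp.abs(dt_back)))  # approximately 0.1: the factor used in dW


def heun_step(f, t, y, dt):
    k1 = f(t, y)
    k2 = f(t + dt, y + dt * k1)
    return y + 0.5 * dt * (k1 + k2)


def integrate(f, y0, t0, dt, n):
    """Hypothetical fixed-step loop; a negative dt integrates backward in time."""
    t0 = jnp.asarray(t0, dtype=jnp.float32)

    def body(carry, _):
        t, y = carry
        return (t + dt, heun_step(f, t, y, dt)), None

    (t_end, y_end), _ = jax.lax.scan(body, (t0, y0), None, length=n)
    return t_end, y_end


decay = lambda t, y: -y
t1, y1 = integrate(decay, jnp.array(1.0), 0.0, 0.01, 100)  # forward to t = 1
t_back, y_back = integrate(decay, y1, t1, -0.01, 100)      # backward to t = 0
```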

AbstractSolver

class pastax.solver.AbstractSolver

Bases: Module

Abstract base class for fixed-step ODE/SDE solvers.

Subclasses implement ode_step() (deterministic) and sde_step() (stochastic, with a pre-sampled z). Solvers that are specific to one mode raise NotImplementedError from the other.

abstractmethod ode_step(term, t, y, dt, args, ctrl)

Advance the ODE state by one step of size dt.

Parameters
  • term (Callable) – Drift callable f(t, y, args) -> Float[Array, "2"] returning velocity in degrees per second.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']

abstractmethod sde_step(term, t, y, dt, args, ctrl, z)

Advance the SDE state by one step using a pre-sampled z.

Parameters
  • term (Callable) – Stochastic dynamics callable f(t, y, args) -> (drift, g). drift is the deterministic velocity in degrees per second; g is the diffusion coefficient with shape (2,) (diagonal) or (2, 2) (full matrix). The term never sees z; the solver applies the Wiener increment.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

  • z (Float[jaxlib._jax.Array, 'n_noise']) – Standard-normal noise sample of shape (n_noise,). The solver uses the Wiener increment $dW = \sqrt{|dt|}\,z$.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']
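The contract can be sketched with Python's abc module. The documented base is Module (likely equinox.Module); this sketch swaps in abc.ABC so it runs without extra dependencies and omits ctrl forwarding for brevity. SolverSketch and OdeOnlyEuler are hypothetical names; the latter shows how a single-mode solver rejects the other mode.

```python
import abc

import jax.numpy as jnp


class SolverSketch(abc.ABC):
    """Stand-in for AbstractSolver: one deterministic and one stochastic step."""

    @abc.abstractmethod
    def ode_step(self, term, t, y, dt, args, ctrl):
        ...

    @abc.abstractmethod
    def sde_step(self, term, t, y, dt, args, ctrl, z):
        ...


class OdeOnlyEuler(SolverSketch):
    def ode_step(self, term, t, y, dt, args, ctrl):
        # ctrl forwarding omitted in this sketch.
        return y + term(t, y, args) * dt

    def sde_step(self, term, t, y, dt, args, ctrl, z):
        # ODE-only solvers (like Tsit5 and Dopri5) raise from the other mode.
        raise NotImplementedError("OdeOnlyEuler does not support SDE mode")


solver = OdeOnlyEuler()
y1 = solver.ode_step(lambda t, y, args: -y, 0.0, jnp.array([1.0, 2.0]), 0.1, None, None)
```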

Euler

class pastax.solver.Euler

Bases: AbstractSolver

Explicit Euler / Euler–Maruyama solver (first-order, fixed-step).

ode_step(term, t, y, dt, args, ctrl)

One Euler step: y_new = y + term(t, y, args) * dt.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

One Euler–Maruyama step: $y + \mathrm{drift}\,dt + g\,dW$.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

  • z (Float[jaxlib._jax.Array, 'n_noise'])

Return type

Float[jaxlib._jax.Array, '2']
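Both Euler updates fit in a few lines for a flat array state. This sketch assumes term(t, y, args) returns dy in ODE mode and (drift, g) with diagonal g in SDE mode; the Ornstein–Uhlenbeck-style ou term is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def euler_ode_step(term, t, y, dt, args):
    return y + term(t, y, args) * dt


def euler_maruyama_step(term, t, y, dt, args, z):
    drift, g = term(t, y, args)
    dW = jnp.sqrt(jnp.abs(dt)) * z  # Wiener increment built by the solver
    return y + drift * dt + g * dW  # diagonal noise


# Mean reversion plus constant diagonal diffusion.
def ou(t, y, args):
    return -args["theta"] * y, args["sigma"] * jnp.ones_like(y)


args = {"theta": 0.5, "sigma": 0.2}
y0 = jnp.array([1.0, -1.0])
z = jax.random.normal(jax.random.PRNGKey(0), y0.shape)
y_sde = euler_maruyama_step(ou, 0.0, y0, 0.01, args, z)
y_ode = euler_ode_step(lambda t, y, a: ou(t, y, a)[0], 0.0, y0, 0.01, args)
```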

Heun

class pastax.solver.Heun

Bases: AbstractSolver

Heun (explicit second-order, two-stage Runge–Kutta) solver.

Convergence order 2 in the ODE case. The SDE step is a Stratonovich predictor–corrector that reuses the same dW in both stages.

ode_step(term, t, y, dt, args, ctrl)

One Heun (trapezoidal) step.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

One Stratonovich Heun step (same dW in predictor and corrector).

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

  • z (Float[jaxlib._jax.Array, 'n_noise'])

Return type

Float[jaxlib._jax.Array, '2']
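A sketch of the Heun pair for a flat state with diagonal noise, under the same assumed term conventions (ODE term returns dy; SDE term returns (drift, g)). In SDE mode the single dW enters both the predictor and the corrector.

```python
import jax.numpy as jnp


def heun_ode_step(term, t, y, dt, args):
    k1 = term(t, y, args)
    k2 = term(t + dt, y + k1 * dt, args)
    return y + 0.5 * (k1 + k2) * dt  # explicit trapezoidal rule


def heun_sde_step(term, t, y, dt, args, z):
    dW = jnp.sqrt(jnp.abs(dt)) * z
    f1, g1 = term(t, y, args)
    y_pred = y + f1 * dt + g1 * dW       # predictor
    f2, g2 = term(t + dt, y_pred, args)  # stage 2 at the predicted state
    return y + 0.5 * (f1 + f2) * dt + 0.5 * (g1 + g2) * dW  # same dW reused
```

On zero-drift geometric noise $dX = \sigma X \circ dW$, this step gives $y(1 + \sigma\,dW + \tfrac12\sigma^2 dW^2)$, matching the expansion of the Stratonovich solution $\exp(\sigma W)$; an Itô scheme would instead carry $\tfrac12\sigma^2(dW^2 - dt)$.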

RK4

class pastax.solver.RK4

Bases: AbstractSolver

Classical fourth-order Runge–Kutta solver (four stages, fixed-step).

Convergence order 4 in the ODE case. The SDE step reuses the same dW across all four stages, yielding a Stratonovich-consistent scheme. Its strong order is much lower than the ODE order of 4 and depends on the noise structure (e.g. additive, diagonal, or general noise).

ode_step(term, t, y, dt, args, ctrl)

One classical RK4 step.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

One stochastic RK4 step (Stratonovich, single dW across stages).

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

  • z (Float[jaxlib._jax.Array, 'n_noise'])

Return type

Float[jaxlib._jax.Array, '2']
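A sketch of the classical RK4 ODE step, with a quick empirical check of its order on dy/dt = -y: halving dt should shrink the error at t = 1 by roughly 2^4 = 16. rk4_step and error_at_one are hypothetical helpers.

```python
import jax.numpy as jnp


def rk4_step(term, t, y, dt, args):
    k1 = term(t, y, args)
    k2 = term(t + 0.5 * dt, y + 0.5 * dt * k1, args)
    k3 = term(t + 0.5 * dt, y + 0.5 * dt * k2, args)
    k4 = term(t + dt, y + dt * k3, args)
    return y + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def error_at_one(n_steps):
    """Absolute error of RK4 on dy/dt = -y, y(0) = 1, at t = 1."""
    decay = lambda t, y, args: -y
    dt = 1.0 / n_steps
    y = jnp.array(1.0)
    for i in range(n_steps):
        y = rk4_step(decay, i * dt, y, dt, None)
    return float(jnp.abs(y - jnp.exp(-1.0)))


ratio = error_at_one(2) / error_at_one(4)  # close to 2**4 for a 4th-order method
```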

Tsit5

class pastax.solver.Tsit5

Bases: AbstractSolver

Tsitouras 5(4)6 explicit Runge–Kutta (ODE-only, fixed-step, order 5).

Six stages, no embedded error estimator: the 4th-order companion row of Tsitouras 2011 is unused because the solver is fixed-step.

ode_step(term, t, y, dt, args, ctrl)

One Tsit5 step (5th-order weights only).

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

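Not supported: Tsit5 is ODE-only, so this method raises NotImplementedError. The signature and parameter descriptions below are inherited from AbstractSolver.sde_step().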

Parameters
  • term (Callable) – Stochastic dynamics callable f(t, y, args) -> (drift, g). drift is the deterministic velocity in degrees per second; g is the diffusion coefficient with shape (2,) (diagonal) or (2, 2) (full matrix). The term never sees z; the solver applies the Wiener increment.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

  • z (Float[jaxlib._jax.Array, 'n_noise']) – Standard-normal noise sample of shape (n_noise,). The solver uses the Wiener increment $dW = \sqrt{|dt|}\,z$.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']

Dopri5

class pastax.solver.Dopri5

Bases: AbstractSolver

Dormand–Prince 5(4)7 explicit Runge–Kutta (ODE-only, fixed-step, order 5).

Seven stages with the first-same-as-last property; here we use the 5th-order row only.

ode_step(term, t, y, dt, args, ctrl)

One Dopri5 step (5th-order weights only).

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

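Not supported: Dopri5 is ODE-only, so this method raises NotImplementedError. The signature and parameter descriptions below are inherited from AbstractSolver.sde_step().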

Parameters
  • term (Callable) – Stochastic dynamics callable f(t, y, args) -> (drift, g). drift is the deterministic velocity in degrees per second; g is the diffusion coefficient with shape (2,) (diagonal) or (2, 2) (full matrix). The term never sees z; the solver applies the Wiener increment.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

  • z (Float[jaxlib._jax.Array, 'n_noise']) – Standard-normal noise sample of shape (n_noise,). The solver uses the Wiener increment $dW = \sqrt{|dt|}\,z$.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']
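For reference, a sketch of one fixed-step Dormand–Prince step that uses only the 5th-order weights. The coefficients are the standard Dormand–Prince 5(4) tableau; since the 5th-order weight of the seventh (FSAL) stage is zero, six stage evaluations suffice when no error estimate is needed. This illustrates the method and is not pastax's implementation.

```python
import jax.numpy as jnp

# Standard Dormand–Prince 5(4) tableau: nodes C, rows of A, 5th-order weights B.
C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0]
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
]
B = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84]


def dopri5_step(term, t, y, dt, args):
    ks = []
    for c_i, a_row in zip(C, A):
        # Stage input: y + dt * sum_j a_ij k_j over the stages computed so far.
        y_stage = y + dt * sum((a * k for a, k in zip(a_row, ks)), jnp.zeros_like(y))
        ks.append(term(t + c_i * dt, y_stage, args))
    return y + dt * sum(b * k for b, k in zip(B, ks))


y = jnp.array([1.0])
for i in range(10):
    y = dopri5_step(lambda t, y, args: -y, 0.1 * i, y, 0.1, None)
```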

EulerHeun

class pastax.solver.EulerHeun

Bases: AbstractSolver

Stochastic Euler–Heun solver (SDE-only, Stratonovich, strong order 1.0).

Matches diffrax’s EulerHeun algorithm: the predictor uses diffusion only and the drift is applied once Euler-style. Accepts both diagonal (g.shape == (2,)) and full (g.shape == (2, 2)) diffusion shapes.

ode_step(term, t, y, dt, args, ctrl)

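Not supported: EulerHeun is SDE-only, so this method raises NotImplementedError. The signature and parameter descriptions below are inherited from AbstractSolver.ode_step().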

Parameters
  • term (Callable) – Drift callable f(t, y, args) -> Float[Array, "2"] returning velocity in degrees per second.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

One stochastic Euler–Heun step.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

  • z (Float[jaxlib._jax.Array, 'n_noise'])

Return type

Float[jaxlib._jax.Array, '2']
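A sketch of the EulerHeun update as described above (diffusion-only predictor, drift applied once Euler-style) for a flat state with diagonal g. It assumes, as in diffrax's EulerHeun, that the corrector re-evaluates the diffusion at t + dt; with additive noise the corrector changes nothing and the step reduces to Euler–Maruyama.

```python
import jax.numpy as jnp


def euler_heun_step(term, t, y, dt, args, z):
    dW = jnp.sqrt(jnp.abs(dt)) * z
    drift, g0 = term(t, y, args)
    y_prime = y + g0 * dW                          # predictor: diffusion only
    _, g1 = term(t + dt, y_prime, args)            # diffusion at the predicted state
    return y + drift * dt + 0.5 * (g0 + g1) * dW   # drift applied once, Euler-style
```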

ItoMilstein

class pastax.solver.ItoMilstein

Bases: AbstractSolver

Itô Milstein solver (SDE-only, diagonal noise, strong order 1.0).

Requires g.shape == (2,). Raises NotImplementedError for matrix-valued g — use EulerHeun for general noise.

ode_step(term, t, y, dt, args, ctrl)

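Not supported: ItoMilstein is SDE-only, so this method raises NotImplementedError. The signature and parameter descriptions below are inherited from AbstractSolver.ode_step().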

Parameters
  • term (Callable) – Drift callable f(t, y, args) -> Float[Array, "2"] returning velocity in degrees per second.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

One Itô Milstein step: $y + f\,dt + g\,dW + \tfrac12\,g\,(\partial g/\partial y)\,(dW^2 - dt)$.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

  • z (Float[jaxlib._jax.Array, 'n_noise'])

Return type

Float[jaxlib._jax.Array, '2']
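A sketch of the diagonal Itô Milstein step for a flat state. It takes ∂g/∂y as the diagonal of jax.jacfwd of y ↦ g(t, y), which is the relevant derivative under the diagonal-noise assumption (g_i depends on y_i only). The geometric Brownian motion term gbm is a hypothetical test case.

```python
import jax
import jax.numpy as jnp


def ito_milstein_step(term, t, y, dt, args, z):
    dW = jnp.sqrt(jnp.abs(dt)) * z
    f, g = term(t, y, args)
    # dg_i/dy_i: diagonal of the Jacobian of y -> g(t, y).
    dg_dy = jnp.diag(jax.jacfwd(lambda yy: term(t, yy, args)[1])(y))
    return y + f * dt + g * dW + 0.5 * g * dg_dy * (dW**2 - dt)


mu, sigma = 0.1, 0.5
gbm = lambda t, y, args: (mu * y, sigma * y)  # dX = mu X dt + sigma X dW
y1 = ito_milstein_step(gbm, 0.0, jnp.array([1.0]), 0.04, None, jnp.array([2.0]))
```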

StratonovichMilstein

class pastax.solver.StratonovichMilstein

Bases: AbstractSolver

Stratonovich Milstein solver (SDE-only, diagonal noise, strong order 1.0).

Requires g.shape == (2,). Differs from ItoMilstein by the absence of the $-\tfrac12\,g\,(\partial g/\partial y)\,dt$ Itô-to-Stratonovich correction.

ode_step(term, t, y, dt, args, ctrl)

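Not supported: StratonovichMilstein is SDE-only, so this method raises NotImplementedError. The signature and parameter descriptions below are inherited from AbstractSolver.ode_step().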

Parameters
  • term (Callable) – Drift callable f(t, y, args) -> Float[Array, "2"] returning velocity in degrees per second.

  • t (Float[jaxlib._jax.Array, '']) – Current time, in seconds.

  • y (Float[jaxlib._jax.Array, '2']) – Current state [lat, lon] in degrees.

  • dt (Float[jaxlib._jax.Array, '']) – Step size in seconds.

  • args (PyTree) – Arbitrary fixed PyTree forwarded to term.

  • ctrl (PyTree) – Arbitrary time-varying PyTree forwarded to term.

Returns

Updated state [lat, lon] in degrees after one step.

Return type

Float[jaxlib._jax.Array, '2']

sde_step(term, t, y, dt, args, ctrl, z)

One Stratonovich Milstein step: $y + f\,dt + g\,dW + \tfrac12\,g\,(\partial g/\partial y)\,dW^2$.

Parameters
  • term (Callable)

  • t (Float[jaxlib._jax.Array, ''])

  • y (Float[jaxlib._jax.Array, '2'])

  • dt (Float[jaxlib._jax.Array, ''])

  • args (PyTree)

  • ctrl (PyTree)

  • z (Float[jaxlib._jax.Array, 'n_noise'])

Return type

Float[jaxlib._jax.Array, '2']
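The Stratonovich variant only drops the −dt inside the correction. Sketched with the same conventions as the Itô version (diagonal noise, ∂g/∂y from the Jacobian diagonal), a zero-drift geometric noise term reproduces the second-order expansion 1 + σdW + ½σ²dW² of the exact Stratonovich solution exp(σW). The names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def stratonovich_milstein_step(term, t, y, dt, args, z):
    dW = jnp.sqrt(jnp.abs(dt)) * z
    f, g = term(t, y, args)
    dg_dy = jnp.diag(jax.jacfwd(lambda yy: term(t, yy, args)[1])(y))
    return y + f * dt + g * dW + 0.5 * g * dg_dy * dW**2  # no -dt: Stratonovich


sigma = 0.5
gbm = lambda t, y, args: (jnp.zeros_like(y), sigma * y)
y1 = stratonovich_milstein_step(gbm, 0.0, jnp.array([1.0]), 0.04, None, jnp.array([1.0]))
```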

solve

pastax.solver.solve(term, y0, t0, n_save, int_dt, save_dt, solver=None, args=None, controls=None, key=None, n_samples=1, brownian_structure=None, adjoint='checkpointed', checkpoints=None)

Integrate a trajectory for n_save output intervals starting at t0.

ODE mode (default, no key): term(t, y[, args, ctrl]) returns dy, the time derivative as a PyTree matching y. SDE mode (pass key): term(t, y[, args, ctrl]) returns (drift, diffusion); the solver draws a standard-normal z and applies $dW = \sqrt{|\mathrm{int\_dt}|}\,z$ internally. The optional ctrl argument is present when controls is provided; the solver slices it at each step, and the term owns its interpretation.

The state y may be any PyTree (a bare array is the single-leaf case). For second-order dynamics dv = f(x, t) dt + noise, dx = v dt, carry y = (x, v) (e.g. a NamedTuple) and return (v, f(x, t)) from the term; put the noise on the velocity leaf only.
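A sketch of the recommended layout, assuming a NamedTuple state and an SDE term that returns (drift, diffusion) as PyTrees matching y. The diffusion is zero on x and σ on v, so noise enters only through the velocity. The hand-written em_step stands in for whichever solver solve() would use; State, oscillator_sde, and em_step are hypothetical names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    x: jax.Array
    v: jax.Array


def oscillator_sde(t, y, args):
    """Drift (dx, dv) = (v, f(x)) with f(x) = -omega^2 x; diffusion zero on x, sigma on v."""
    drift = State(x=y.v, v=-args["omega"] ** 2 * y.x)
    diffusion = State(x=jnp.zeros_like(y.x), v=args["sigma"] * jnp.ones_like(y.v))
    return drift, diffusion


def em_step(term, t, y, dt, args, z):
    drift, g = term(t, y, args)
    dW = jax.tree_util.tree_map(lambda zi: jnp.sqrt(jnp.abs(dt)) * zi, z)
    return jax.tree_util.tree_map(lambda yi, fi, gi, dwi: yi + fi * dt + gi * dwi, y, drift, g, dW)


args = {"omega": 1.0, "sigma": 0.2}
y0 = State(x=jnp.array(1.0), v=jnp.array(0.0))
z = State(x=jnp.array(0.7), v=jnp.array(-1.3))  # noise shaped like y0 (the default)
y1 = em_step(oscillator_sde, 0.0, y0, 0.01, args, z)
```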

The solver runs on a fine integration grid of n_fine = n_save * n_substeps steps (where n_substeps = round(save_dt / int_dt)), then slices every n_substeps steps to produce the n_save + 1 saved states.
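A minimal sketch of that loop under stated assumptions (flat state, ODE mode, a generic step function): jax.lax.scan runs all n_fine steps, a strided slice keeps every n_substeps-th state, and the initial state is prepended to give n_save + 1 outputs. integrate_and_save is hypothetical, and pastax may organise the saving differently.

```python
import jax
import jax.numpy as jnp


def integrate_and_save(step, y0, t0, n_save, int_dt, save_dt):
    n_substeps = round(save_dt / int_dt)
    n_fine = n_save * n_substeps

    def body(y, i):
        y_new = step(t0 + i * int_dt, y, int_dt)
        return y_new, y_new

    _, ys = jax.lax.scan(body, y0, jnp.arange(n_fine))
    saved = ys[n_substeps - 1 :: n_substeps]           # states at t0 + k * save_dt, k = 1..n_save
    return jnp.concatenate([y0[None], saved], axis=0)  # leading n_save + 1 axis


euler = lambda t, y, dt: y + (-y) * dt
traj = integrate_and_save(euler, jnp.array([1.0, 2.0]), 0.0, n_save=3, int_dt=0.1, save_dt=0.5)
```

Stacking all n_fine states before slicing is the simplest reading of the description; a nested scan (outer over n_save, inner over n_substeps) would store only the saved states.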

Ensemble: pass n_samples > 1 in SDE mode; the key is split internally. Perturbed-ODE ensemble: run in ODE mode with controls and vmap over a batch of controls, e.g. jax.vmap(lambda c: solve(..., controls=c))(controls_batch).
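Both ensemble styles reduce to jax.vmap. Here a hypothetical toy_solve stands in for solve(): a plain Euler(-Maruyama) loop that accepts per-step controls in ODE mode and a key in SDE mode. Mapping over split keys yields independent noise realisations, while mapping over a batch of controls yields a perturbed-ODE ensemble.

```python
import jax
import jax.numpy as jnp


def toy_solve(y0, n_steps, dt, controls=None, key=None, sigma=0.1):
    """Hypothetical stand-in for solve(): dy = (-y + ctrl) dt [+ sigma dW]."""
    shape = (n_steps,) + y0.shape
    if controls is None:
        controls = jnp.zeros(shape)
    if key is None:
        noise = jnp.zeros(shape)
    else:
        noise = sigma * jnp.sqrt(jnp.abs(dt)) * jax.random.normal(key, shape)

    def body(y, inputs):
        ctrl, dw = inputs
        y_new = y + (-y + ctrl) * dt + dw
        return y_new, y_new

    _, ys = jax.lax.scan(body, y0, (controls, noise))
    return ys


y0 = jnp.array([1.0, 0.0])
n_steps, dt = 50, 0.02

# SDE ensemble: split the key, vmap over the keys.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
sde_ens = jax.vmap(lambda k: toy_solve(y0, n_steps, dt, key=k))(keys)

# Perturbed ODE: vmap over a batch of per-step controls.
controls_batch = 0.05 * jax.random.normal(jax.random.PRNGKey(1), (4, n_steps, 2))
ode_ens = jax.vmap(lambda c: toy_solve(y0, n_steps, dt, controls=c))(controls_batch)
```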

Parameters
  • term (Callable) – Dynamics callable f(t, y[, args, ctrl]). ODE: returns dy (PyTree matching y). SDE: returns (drift, diffusion), where diffusion is a PyTree matching the noise (diagonal), a 2-D array (state, n_noise) (matrix), or a lineax linear operator.

  • y0 (Float[jaxlib._jax.Array, '2']) – Initial state. Any PyTree; a bare array (e.g. [lat, lon], shape (2,)) is the single-leaf case. Defines the output structure.

  • t0 (Float[jaxlib._jax.Array, '']) – Start time in seconds. A JAX scalar, so it can change between calls without recompilation. The implicit end time is t0 + n_save * save_dt.

  • n_save (int) – Number of output intervals (static). Each output leaf has a leading n_save + 1 axis, including the initial state.

  • int_dt (float) – Integration step size in seconds (static). Use a negative value for backward-in-time integration.

  • save_dt (float) – Output interval in seconds (static). Must be an integer multiple of int_dt with the same sign, so that n_substeps = round(save_dt / int_dt) >= 1.

  • solver (AbstractSolver | None) – Solver instance. Defaults to Heun().

  • args (PyTree | None) – Arbitrary fixed PyTree passed through to term (e.g. a Dataset).

  • controls (PyTree | None) – Arbitrary per-step PyTree with leading axis n_fine. Sliced at each integration step.

  • key (Key[jaxlib._jax.Array, ''] | None) – PRNG key for SDE mode. When provided, the solver draws a standard-normal noise sequence (shaped per brownian_structure) and runs in SDE mode.

  • n_samples (int) – Number of independent SDE realisations (default 1). Ignored in ODE mode. When > 1, the key is split and trajectories are vmapped, adding a leading n_samples axis to every output leaf.

  • brownian_structure (PyTree | None) – Optional prototype PyTree of jax.ShapeDtypeStruct (or arrays) describing the Wiener process, used when the noise space differs from the state space (e.g. an m-dim Brownian motion driving an n-dim state via a lineax operator). Defaults to y0’s structure and per-leaf shapes.

  • adjoint (str) – Differentiation strategy for the integration loop. "checkpointed" (default) uses binomial checkpointing (treeverse) for low reverse-mode memory, but is reverse-mode only: jax.jvp is not supported. "forward" uses a plain jax.lax.scan without per-step checkpointing, which suits forward-mode AD (jax.jvp / jax.jacfwd) and mirrors diffrax.ForwardMode. Reverse mode also works through this path, at the cost of O(n_fine) activation memory.

  • checkpoints (int | str | None) – Memory knob for adjoint="checkpointed" (ignored otherwise). None (default) forwards to equinox.internal.scan’s built-in O(sqrt(n_fine)) Stumm–Walther online schedule, the same default that diffrax.RecursiveCheckpointAdjoint uses, balancing memory against backward recompute. A smaller int (e.g. ceil(log2(n_fine))) saves memory at the cost of more recompute; a larger int (or "all" for one checkpoint per step) spends more memory to reduce recompute.

Returns

A PyTree with the structure of y0; each leaf gains a leading n_save + 1 axis (and an extra leading n_samples axis in SDE mode with n_samples > 1). For a flat (2,) state this is shape (n_save + 1, 2) or (n_samples, n_save + 1, 2).

Return type

Array
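The static-argument rules for int_dt and save_dt can be checked up front. This hypothetical helper, not part of pastax, derives n_substeps, n_fine, and the implicit end time t0 + n_save * save_dt, and rejects a save_dt that is not a same-sign integer multiple of int_dt. The tolerance rtol is an assumption of the sketch.

```python
def grid_sizes(t0, n_save, int_dt, save_dt, rtol=1e-9):
    """Hypothetical: return (n_substeps, n_fine, t_end) for solve()-style arguments."""
    if int_dt == 0 or save_dt == 0:
        raise ValueError("int_dt and save_dt must be non-zero")
    if (int_dt > 0) != (save_dt > 0):
        raise ValueError("int_dt and save_dt must have the same sign")
    ratio = save_dt / int_dt
    n_substeps = round(ratio)
    if n_substeps < 1 or abs(ratio - n_substeps) > rtol * abs(ratio):
        raise ValueError(f"save_dt={save_dt} is not an integer multiple of int_dt={int_dt}")
    n_fine = n_save * n_substeps
    t_end = t0 + n_save * save_dt
    return n_substeps, n_fine, t_end


print(grid_sizes(0.0, 24, 300.0, 3600.0))    # (12, 288, 86400.0): hourly output, 5-minute steps
print(grid_sizes(0.0, 24, -300.0, -3600.0))  # (12, 288, -86400.0): same grid, backward in time
```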