Proper scoring rules for probabilistic (ensemble) trajectory forecasts.
Implements four scoring rules (see Pic et al., 2025) for ensemble forecasts
of shape (S, T, 2) evaluated against an observed trajectory (T, 2):
- squared_error(): squared distance between the deterministic ensemble mean and the observation.
- dawid_sebastiani(): Gaussian-likelihood-based, no kernel.
- energy_score(): kernel-based proper scoring rule (unbiased estimator).
- variogram_score(): component-wise pairwise-difference score.
All scores follow the negative orientation convention: lower is better.
Each score accepts a reduce argument:
- reduce=None returns the per-time score of shape (T,).
- reduce="last" returns the scalar score at the final time.
- reduce="sum" returns (weights * score).sum(), defaulting to a uniform sum when weights is None. By Proposition 2 of Pic et al., a non-negative-weighted sum of proper scoring rules is itself proper.
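The three reduction modes can be illustrated with a small stand-alone sketch. The helper `apply_reduce` is a hypothetical name introduced here for illustration; it mirrors the documented behaviour but is not taken from the pastax source.

```python
import jax.numpy as jnp


def apply_reduce(score, reduce=None, weights=None):
    """Hypothetical helper mirroring the documented `reduce` options."""
    if reduce is None:
        return score                        # per-time vector, shape (T,)
    if reduce == "last":
        return score[-1]                    # scalar at the final time
    if reduce == "sum":
        if weights is None:
            weights = jnp.ones_like(score)  # uniform sum by default
        return (weights * score).sum()      # scalar
    raise ValueError(f"unknown reduce: {reduce!r}")


score = jnp.array([1.0, 2.0, 4.0])          # a per-time score with T = 3
weights = jnp.array([0.5, 0.25, 0.25])      # non-negative per-time weights

last = apply_reduce(score, "last")               # 4.0
total = apply_reduce(score, "sum")               # 7.0
weighted = apply_reduce(score, "sum", weights)   # 2.0
```

Keeping the weights non-negative is what lets the weighted sum inherit propriety from the per-time scores.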
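The default distance kernel for squared_error() and energy_score() is the
Euclidean distance. Any callable that satisfies the broadcasting kernel
contract can be passed instead: it takes two arrays whose trailing axis has
size 2, broadcasts over their leading axes, and returns the pointwise distance
with the trailing axis removed. A notable example is
pastax.metric.separation_distance(), which gives great-circle distances on the sphere.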
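As an illustration of the kernel contract, the sketch below defines a haversine distance that broadcasts over leading axes. The name `haversine_kernel`, the (latitude, longitude) in degrees convention, and the Earth radius constant are assumptions made for this example; they do not describe how pastax.metric.separation_distance() is implemented.

```python
import jax.numpy as jnp

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius in metres (illustrative choice)


def haversine_kernel(a, b):
    """Hypothetical kernel: great-circle distance between (..., 2) points.

    Assumes each point is (latitude, longitude) in degrees.
    """
    lat_a, lon_a = jnp.deg2rad(a[..., 0]), jnp.deg2rad(a[..., 1])
    lat_b, lon_b = jnp.deg2rad(b[..., 0]), jnp.deg2rad(b[..., 1])
    h = (jnp.sin((lat_b - lat_a) / 2) ** 2
         + jnp.cos(lat_a) * jnp.cos(lat_b) * jnp.sin((lon_b - lon_a) / 2) ** 2)
    # Clip guards against rounding pushing h slightly outside [0, 1].
    return 2 * EARTH_RADIUS_M * jnp.arcsin(jnp.sqrt(jnp.clip(h, 0.0, 1.0)))


forecast = jnp.zeros((5, 4, 2))       # ensemble, shape (S, T, 2)
obs = jnp.ones((4, 2))                # observation, shape (T, 2)
d = haversine_kernel(forecast, obs)   # obs broadcasts over S, result (S, T)
```

The key property is the shape behaviour: inputs of shape (..., 2) go in, an array of shape (...) comes out, so the same callable works for member-to-observation and member-to-member comparisons.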
squared_error
- pastax.score.squared_error(forecast, obs, *, kernel=l2_distance, reduce=None, weights=None)
Squared distance between ensemble mean and observation.
With the default L2 kernel this is the squared error of the ensemble mean (Pic et al. 2025, Eq. 11).
- Parameters
  - forecast (Float[jaxlib._jax.Array, 'S T 2']) – Ensemble forecast, shape (S, T, 2).
  - obs (Float[jaxlib._jax.Array, 'T 2']) – Observed trajectory, shape (T, 2).
  - kernel (Callable[[Float[jaxlib._jax.Array, '...2'], Float[jaxlib._jax.Array, '...2']], Float[jaxlib._jax.Array, '...']]) – Broadcasting distance kernel. Defaults to l2_distance().
  - reduce (Literal['last', 'sum'] | None) – Time reduction. None returns the per-time vector; "last" returns the scalar at the final time; "sum" returns the (optionally weighted) sum over time.
  - weights (Float[jaxlib._jax.Array, 'T'] | None) – Per-time weights for reduce="sum"; ignored otherwise.
- Returns
  Per-time score of shape (T,) or a scalar, per reduce.
- Return type
  Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']
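The per-time computation described for squared_error() can be sketched as follows. The names `l2` and `squared_error_sketch` are hypothetical stand-ins, and averaging the members in coordinate space before applying the kernel is an assumption based on the description "squared distance between ensemble mean and observation".

```python
import jax.numpy as jnp


def l2(a, b):
    # Euclidean distance over the trailing axis; stand-in for the default kernel.
    return jnp.linalg.norm(a - b, axis=-1)


def squared_error_sketch(forecast, obs, kernel=l2):
    mean = forecast.mean(axis=0)       # ensemble mean, shape (T, 2)
    return kernel(mean, obs) ** 2      # per-time score, shape (T,)


forecast = jnp.array([[[0.0, 0.0], [1.0, 1.0]],
                      [[2.0, 0.0], [3.0, 1.0]]])   # S = 2, T = 2
obs = jnp.array([[1.0, 0.0], [2.0, 3.0]])          # T = 2

score = squared_error_sketch(forecast, obs)        # [0., 4.]
```

At the first time step the ensemble mean (1, 0) coincides with the observation, so the score is 0; at the second the mean (2, 1) lies 2 units from (2, 3), giving 4.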
dawid_sebastiani
- pastax.score.dawid_sebastiani(forecast, obs, *, reduce=None, weights=None)
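Dawid-Sebastiani score: negative Gaussian log-likelihood of the observation under the ensemble mean and covariance, up to an affine transformation.
The per-time score is

$$\mathrm{DS}_t = \log\det \Sigma_t + (y_t - \mu_t)^\top \Sigma_t^{-1} (y_t - \mu_t),$$

where $y_t$ is the observation, $\mu_t$ is the ensemble mean, and $\Sigma_t$ is the unbiased (ddof=1) sample covariance of the ensemble at time $t$. Requires $S \ge 3$ for $\Sigma_t$ to be a.s. full-rank on $\mathbb{R}^2$; for $S < 3$ the score is undefined (singular covariance).
- Parameters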
  - forecast (Float[jaxlib._jax.Array, 'S T 2']) – Ensemble forecast, shape (S, T, 2), with S >= 3.
  - obs (Float[jaxlib._jax.Array, 'T 2']) – Observed trajectory, shape (T, 2).
  - reduce (Literal['last', 'sum'] | None) – See squared_error().
  - weights (Float[jaxlib._jax.Array, 'T'] | None) – See squared_error().
- Returns
  Per-time score of shape (T,) or a scalar, per reduce.
- Return type
  Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']
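A minimal sketch of a Dawid-Sebastiani computation, assuming the common form log det Σ + (y − μ)ᵀ Σ⁻¹ (y − μ) with the ddof=1 sample covariance. The function name is hypothetical, and any constant factors used by pastax itself are not confirmed by this page.

```python
import jax.numpy as jnp


def dawid_sebastiani_sketch(forecast, obs):
    S = forecast.shape[0]
    mean = forecast.mean(axis=0)                            # (T, 2)
    dev = forecast - mean                                   # (S, T, 2)
    # Unbiased (ddof=1) sample covariance at each time step: (T, 2, 2)
    cov = jnp.einsum("sti,stj->tij", dev, dev) / (S - 1)
    err = obs - mean                                        # (T, 2)
    _, logdet = jnp.linalg.slogdet(cov)                     # (T,)
    # Solve cov @ z = err per time step instead of inverting cov.
    z = jnp.linalg.solve(cov, err[..., None])[..., 0]       # (T, 2)
    maha = jnp.einsum("ti,ti->t", err, z)                   # squared Mahalanobis
    return logdet + maha                                    # (T,)


# S = 3 members, T = 1 time step; sample covariance is diag(1, 3).
forecast = jnp.array([[[0.0, 0.0]], [[2.0, 0.0]], [[1.0, 3.0]]])
obs = jnp.array([[2.0, 1.0]])

score = dawid_sebastiani_sketch(forecast, obs)              # log(3) + 1
```

With only two members the deviations are collinear, the 2×2 covariance has rank 1, and the log-determinant diverges, which is why at least three members are needed.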
energy_score
- pastax.score.energy_score(forecast, obs, *, kernel=l2_distance, alpha=1.0, reduce=None, weights=None)
Energy score (Pic et al. 2025, Eq. 12), unbiased Monte Carlo estimator.
The pairwise term is computed as a full (S, S) mean (including the zero diagonal) multiplied by S/(S-1), which recovers the unbiased off-diagonal estimator exactly: the diagonal adds nothing to the sum, so rescaling the $1/S^2$ normalisation by $S/(S-1)$ gives $1/(S(S-1))$. Strictly proper for the L2 kernel and $\alpha \in (0, 2)$; propriety with other kernels is not guaranteed.
- Parameters
  - forecast (Float[jaxlib._jax.Array, 'S T 2']) – Ensemble forecast, shape (S, T, 2), with S >= 2.
  - obs (Float[jaxlib._jax.Array, 'T 2']) – Observed trajectory, shape (T, 2).
  - kernel (Callable[[Float[jaxlib._jax.Array, '...2'], Float[jaxlib._jax.Array, '...2']], Float[jaxlib._jax.Array, '...']]) – Broadcasting distance kernel. Defaults to l2_distance().
  - alpha (float) – Distance exponent (typically in (0, 2)). Default 1.0.
  - reduce (Literal['last', 'sum'] | None) – See squared_error().
  - weights (Float[jaxlib._jax.Array, 'T'] | None) – See squared_error().
- Returns
  Per-time score of shape (T,) or a scalar, per reduce.
- Return type
  Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']
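The estimator described above can be sketched as follows, assuming the usual energy-score form: mean member-to-observation distance minus half the mean member-to-member distance. The names `l2` and `energy_score_sketch` are hypothetical, and this is not the pastax source.

```python
import jax.numpy as jnp


def l2(a, b):
    # Euclidean distance over the trailing axis; stand-in for the default kernel.
    return jnp.linalg.norm(a - b, axis=-1)


def energy_score_sketch(forecast, obs, kernel=l2, alpha=1.0):
    S = forecast.shape[0]
    # Mean distance from each member to the observation: (S, T) -> (T,)
    to_obs = (kernel(forecast, obs) ** alpha).mean(axis=0)
    # All member pairs: (S, 1, T, 2) against (1, S, T, 2) -> (S, S, T)
    pair = kernel(forecast[:, None], forecast[None, :]) ** alpha
    # The full mean includes the zero diagonal; S / (S - 1) rescales it
    # to the mean over the S * (S - 1) off-diagonal pairs.
    spread = pair.mean(axis=(0, 1)) * S / (S - 1)
    return to_obs - 0.5 * spread                       # (T,)


forecast = jnp.array([[[0.0, 0.0]], [[2.0, 0.0]]])     # S = 2, T = 1
obs = jnp.array([[4.0, 0.0]])

score = energy_score_sketch(forecast, obs)             # [2.]
```

Here the members sit 4 and 2 units from the observation (mean 3) and 2 units from each other, so the score is 3 − 0.5 · 2 = 2.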
variogram_score
- pastax.score.variogram_score(forecast, obs, *, p=2.0, component_weights=None, reduce=None, weights=None)
Variogram score of order p (Pic et al. 2025, Eq. 13).
Sums over all ordered component pairs $(i, j)$; with the default component_weights = 1 - I, the diagonal contribution (zero) is masked out and the off-diagonal pair is counted twice (symmetric formulation).
- Parameters
  - forecast (Float[jaxlib._jax.Array, 'S T 2']) – Ensemble forecast, shape (S, T, 2).
  - obs (Float[jaxlib._jax.Array, 'T 2']) – Observed trajectory, shape (T, 2).
  - p (float) – Variogram order. Default 2.0.
  - component_weights (Float[jaxlib._jax.Array, '2 2'] | None) – (2, 2) non-negative weight matrix. Defaults to ones((2, 2)) - eye(2).
  - reduce (Literal['last', 'sum'] | None) – See squared_error().
  - weights (Float[jaxlib._jax.Array, 'T'] | None) – See squared_error().
- Returns
  Per-time score of shape (T,) or a scalar, per reduce.
- Return type
  Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']
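A sketch of the component-wise computation, assuming the usual variogram-score form: for each component pair, the squared difference between the observed |y_i − y_j|^p and the ensemble mean of |x_i − x_j|^p, summed with the component weights. The function name is hypothetical and the code is not taken from pastax.

```python
import jax.numpy as jnp


def variogram_score_sketch(forecast, obs, p=2.0, component_weights=None):
    if component_weights is None:
        # Documented default: off-diagonal ones, zero diagonal.
        component_weights = jnp.ones((2, 2)) - jnp.eye(2)
    # |x_i - x_j|^p for every component pair: (S, T, 2, 2)
    ens = jnp.abs(forecast[..., :, None] - forecast[..., None, :]) ** p
    # Same quantity for the observation: (T, 2, 2)
    ob = jnp.abs(obs[..., :, None] - obs[..., None, :]) ** p
    sq = (ob - ens.mean(axis=0)) ** 2                      # (T, 2, 2)
    return (component_weights * sq).sum(axis=(-2, -1))     # (T,)


forecast = jnp.array([[[0.0, 1.0]], [[0.0, 3.0]]])         # S = 2, T = 1
obs = jnp.array([[0.0, 2.0]])

score = variogram_score_sketch(forecast, obs)              # [2.]
```

In this example the ensemble mean of |x_1 − x_2|² is (1 + 9) / 2 = 5 and the observed value is 4, so each off-diagonal entry contributes (4 − 5)² = 1; counting both orderings gives 2.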