pastax.score

Proper scoring rules for probabilistic (ensemble) trajectory forecasts.

Implements four scoring rules (see Pic et al., 2025) for ensemble forecasts of shape (S, T, 2) evaluated against an observed trajectory of shape (T, 2): squared_error(), dawid_sebastiani(), energy_score(), and variogram_score().

All scores follow the negative orientation convention: lower is better.

Each score accepts a reduce argument:

  • reduce=None returns the per-time score of shape (T,).

  • reduce="last" returns the scalar score at the final time.

  • reduce="sum" returns (weights * score).sum(), defaulting to a uniform sum when weights is None. By Proposition 2 of Pic et al., a non-negative-weighted sum of proper scoring rules is itself proper.
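The three reduction modes above can be sketched as a small helper. This is only an illustration of the documented semantics: the name `reduce_score` is hypothetical and is not part of pastax.

```python
import jax.numpy as jnp


def reduce_score(score, reduce=None, weights=None):
    """Hypothetical helper mirroring the documented `reduce` semantics."""
    if reduce is None:
        return score                      # per-time vector, shape (T,)
    if reduce == "last":
        return score[-1]                  # scalar at the final time
    if reduce == "sum":
        if weights is None:
            return score.sum()            # uniform sum over time
        return (weights * score).sum()    # weighted sum over time
    raise ValueError(f"unknown reduce: {reduce!r}")


score = jnp.array([1.0, 2.0, 4.0])          # a made-up per-time score, T = 3
weights = jnp.array([0.5, 0.25, 0.25])      # non-negative per-time weights

per_time = reduce_score(score)
last = reduce_score(score, "last")
total = reduce_score(score, "sum")
weighted = reduce_score(score, "sum", weights)
```

Keeping the weights non-negative matters here, since that is the condition under which the weighted sum stays proper.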

The default distance kernel for squared_error() and energy_score() is the Euclidean distance. A user may pass any callable satisfying the broadcasting kernel contract — notably pastax.metric.separation_distance() for great-circle distances on the sphere.
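As a sketch of what the broadcasting kernel contract implies, the snippet below defines a hypothetical Manhattan-distance kernel `l1_kernel` (not part of pastax) that reduces the trailing axis of size 2 and broadcasts over all leading axes. It is shown only to illustrate the expected shapes; propriety of a score under a non-default kernel is not guaranteed.

```python
import jax.numpy as jnp


def l1_kernel(x, y):
    """Hypothetical kernel: Manhattan distance over the trailing axis of size 2."""
    return jnp.abs(x - y).sum(axis=-1)


forecast = jnp.zeros((4, 3, 2))       # ensemble, shape (S, T, 2)
obs = jnp.ones((3, 2))                # observation, shape (T, 2)

# (S, T, 2) against (T, 2) broadcasts to (S, T)
d_obs = l1_kernel(forecast, obs)

# pairwise member distances: (S, 1, T, 2) against (1, S, T, 2) gives (S, S, T)
d_pair = l1_kernel(forecast[:, None], forecast[None, :])
```

Any callable that maps two `(..., 2)` arrays to a `(...)` array in this way should fit the same slot as the default kernel.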

squared_error

pastax.score.squared_error(forecast, obs, *, kernel=l2_distance, reduce=None, weights=None)

Squared distance between ensemble mean and observation.

$$\mathrm{SE}_t = \operatorname{kernel}\!\left( \operatorname{mean}_s \mathrm{forecast}[s, t],\ \mathrm{obs}[t]\right)^2$$

With the default L2 kernel this is the squared error of the ensemble mean (Pic et al. 2025, Eq. 11).
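A minimal sketch of the formula above, assuming a plain Euclidean kernel. The names `l2` and `squared_error_sketch` are illustrative stand-ins, not the library's implementation, and the `reduce` and `weights` arguments are omitted.

```python
import jax.numpy as jnp


def l2(x, y):
    # Euclidean distance over the trailing coordinate axis
    return jnp.sqrt(((x - y) ** 2).sum(axis=-1))


def squared_error_sketch(forecast, obs, kernel=l2):
    mean = forecast.mean(axis=0)          # ensemble mean, shape (T, 2)
    return kernel(mean, obs) ** 2         # per-time score, shape (T,)


forecast = jnp.array([
    [[0.0, 0.0], [1.0, 1.0]],
    [[2.0, 0.0], [3.0, 1.0]],
])                                        # S = 2 members, T = 2 times
obs = jnp.array([[1.0, 0.0], [2.0, 3.0]])

se = squared_error_sketch(forecast, obs)
```

In this toy case the ensemble mean coincides with the observation at the first time and lies at distance 2 from it at the second, so the per-time scores are 0 and 4.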

Parameters
  • forecast (Float[jaxlib._jax.Array, 'S T 2']): Ensemble forecast, shape (S, T, 2).

  • obs (Float[jaxlib._jax.Array, 'T 2']): Observed trajectory, shape (T, 2).

  • kernel (Callable[[Float[jaxlib._jax.Array, '... 2'], Float[jaxlib._jax.Array, '... 2']], Float[jaxlib._jax.Array, '...']]): Broadcasting distance kernel. Defaults to l2_distance().

  • reduce (Literal['last', 'sum'] | None): Time reduction. None returns the per-time vector; "last" returns the scalar at the final time; "sum" returns the (optionally weighted) sum over time.

  • weights (Float[jaxlib._jax.Array, 'T'] | None): Per-time weights for reduce="sum"; ignored otherwise.

Returns

Per-time score of shape (T,) or a scalar, per reduce.

Return type

Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']

dawid_sebastiani

pastax.score.dawid_sebastiani(forecast, obs, *, reduce=None, weights=None)

Dawid-Sebastiani score: negative Gaussian log-likelihood of the observation under the ensemble's sample mean and covariance, up to a factor of 2 and an additive constant.

The per-time score is

$$\mathrm{DS}_t = \log\det \Sigma_t + (\mu_t - y_t)^{\top}\, \Sigma_t^{-1}\, (\mu_t - y_t)$$

where $\mu_t$ is the ensemble mean, $y_t$ is the observation, and $\Sigma_t$ is the unbiased (ddof=1) sample covariance of the ensemble at time $t$. Requires $S \geq 3$ for $\Sigma_t$ to be almost surely full-rank on $\mathbb{R}^2$; for $S \leq 2$ the score is undefined (singular covariance).
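The per-time formula can be sketched in JAX as follows. This is an illustrative reading of the definition, not the library's code: `dawid_sebastiani_sketch` is a hypothetical name, and the `reduce` and `weights` arguments are left out.

```python
import jax.numpy as jnp


def dawid_sebastiani_sketch(forecast, obs):
    S = forecast.shape[0]
    mu = forecast.mean(axis=0)                              # ensemble mean, (T, 2)
    dev = forecast - mu                                     # deviations, (S, T, 2)
    # unbiased (ddof=1) sample covariance per time step, shape (T, 2, 2)
    cov = jnp.einsum("sti,stj->tij", dev, dev) / (S - 1)
    diff = mu - obs                                         # (T, 2)
    # solve cov @ x = diff per time step instead of forming the inverse
    sol = jnp.linalg.solve(cov, diff[..., None])[..., 0]    # (T, 2)
    _, logdet = jnp.linalg.slogdet(cov)                     # (T,)
    return logdet + (diff * sol).sum(axis=-1)               # (T,)


# Four members on the corners of a square, repeated at T = 2 times.
members = jnp.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
forecast = jnp.stack([members, members], axis=1)            # (S, T, 2) = (4, 2, 2)
obs = jnp.array([[0.0, 0.0], [2.0, 0.0]])

ds = dawid_sebastiani_sketch(forecast, obs)
```

Here the sample covariance is (4/3) I at both times, so the log-determinant term is 2 log(4/3); the quadratic term is 0 at the first time and 3 at the second.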

Parameters
  • forecast (Float[jaxlib._jax.Array, 'S T 2']): Ensemble forecast, shape (S, T, 2), with S >= 3.

  • obs (Float[jaxlib._jax.Array, 'T 2']): Observed trajectory, shape (T, 2).

  • reduce (Literal['last', 'sum'] | None): See squared_error().

  • weights (Float[jaxlib._jax.Array, 'T'] | None): See squared_error().

Returns

Per-time score of shape (T,) or a scalar, per reduce.

Return type

Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']

energy_score

pastax.score.energy_score(forecast, obs, *, kernel=l2_distance, alpha=1.0, reduce=None, weights=None)

Energy score (Pic et al. 2025, Eq. 12) — unbiased Monte Carlo estimator.

$$\mathrm{ES}_t = \frac{1}{S} \sum_s d\!\left(X_t^{(s)}, y_t\right)^{\alpha} - \frac{1}{2 S (S-1)} \sum_{s \neq s'} d\!\left(X_t^{(s)}, X_t^{(s')}\right)^{\alpha}$$

The pairwise term is computed as a full (S, S) mean (including the zero diagonal) multiplied by S/(S-1), which recovers the unbiased off-diagonal estimator exactly. Strictly proper for the L2 kernel and $\alpha \in (0, 2)$; propriety with other kernels is not guaranteed.
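The estimator and the S/(S-1) rescaling of the full pairwise mean can be sketched as below. The names `l2` and `energy_score_sketch` are hypothetical, a Euclidean kernel is assumed, and the `reduce` and `weights` arguments are omitted.

```python
import jax.numpy as jnp


def l2(x, y):
    # Euclidean distance over the trailing coordinate axis
    return jnp.sqrt(((x - y) ** 2).sum(axis=-1))


def energy_score_sketch(forecast, obs, kernel=l2, alpha=1.0):
    S = forecast.shape[0]
    # member-to-observation distances, shape (S, T)
    d_obs = kernel(forecast, obs) ** alpha
    # member-to-member distances, shape (S, S, T); the diagonal is zero
    d_pair = kernel(forecast[:, None], forecast[None, :]) ** alpha
    # full (S, S) mean including the diagonal, rescaled by S / (S - 1)
    spread = d_pair.mean(axis=(0, 1)) * S / (S - 1)
    return d_obs.mean(axis=0) - 0.5 * spread              # (T,)


members = jnp.array([[0.0, 0.0], [2.0, 0.0]])             # S = 2
forecast = jnp.stack([members, members], axis=1)          # (S, T, 2) = (2, 2, 2)
obs = jnp.array([[1.0, 0.0], [4.0, 0.0]])

es = energy_score_sketch(forecast, obs)
```

In this toy case the two members are 2 apart, so the spread term is 1 at both times; the mean distance to the observation is 1 at the first time and 3 at the second, giving scores of 0 and 2.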

Parameters
  • forecast (Float[jaxlib._jax.Array, 'S T 2']): Ensemble forecast, shape (S, T, 2), with S >= 2.

  • obs (Float[jaxlib._jax.Array, 'T 2']): Observed trajectory, shape (T, 2).

  • kernel (Callable[[Float[jaxlib._jax.Array, '... 2'], Float[jaxlib._jax.Array, '... 2']], Float[jaxlib._jax.Array, '...']]): Broadcasting distance kernel. Defaults to l2_distance().

  • alpha (float): Distance exponent (typically in (0, 2)). Default 1.0.

  • reduce (Literal['last', 'sum'] | None): See squared_error().

  • weights (Float[jaxlib._jax.Array, 'T'] | None): See squared_error().

Returns

Per-time score of shape (T,) or a scalar, per reduce.

Return type

Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']

variogram_score

pastax.score.variogram_score(forecast, obs, *, p=2.0, component_weights=None, reduce=None, weights=None)

Variogram score of order p (Pic et al. 2025, Eq. 13).

$$\mathrm{VS}_t = \sum_{i,j} w_{ij} \left( \mathbb{E}_F\!\left[\,|X_{t,i} - X_{t,j}|^{p}\right] - |y_{t,i} - y_{t,j}|^{p}\right)^2$$

The sum runs over all ordered component pairs $(i, j)$; with the default component_weights = 1 - I, the diagonal terms (which are zero) are masked out and the single off-diagonal pair is counted twice (symmetric formulation).
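A sketch of the per-time variogram score, with the expectation under F approximated by the ensemble mean (an assumption about the estimator, since the formula above only states the expectation). The name `variogram_score_sketch` is hypothetical, and the `reduce` and `weights` arguments are omitted.

```python
import jax.numpy as jnp


def variogram_score_sketch(forecast, obs, p=2.0, component_weights=None):
    if component_weights is None:
        # default documented above: ones((2, 2)) - eye(2)
        component_weights = jnp.ones((2, 2)) - jnp.eye(2)
    # |X_i - X_j|^p for each member and time, shape (S, T, 2, 2)
    f_diff = jnp.abs(forecast[..., :, None] - forecast[..., None, :]) ** p
    # |y_i - y_j|^p for each time, shape (T, 2, 2)
    o_diff = jnp.abs(obs[..., :, None] - obs[..., None, :]) ** p
    # ensemble mean stands in for the expectation under F
    err = (f_diff.mean(axis=0) - o_diff) ** 2             # (T, 2, 2)
    return (component_weights * err).sum(axis=(-2, -1))   # (T,)


forecast = jnp.array([[[0.0, 1.0]], [[0.0, 3.0]]])        # S = 2, T = 1
obs = jnp.array([[0.0, 2.0]])

vs = variogram_score_sketch(forecast, obs)
```

In this toy case the mean squared component difference is (1 + 9) / 2 = 5 against an observed 4, so each off-diagonal entry contributes 1 and the symmetric double count gives a score of 2.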

Parameters
  • forecast (Float[jaxlib._jax.Array, 'S T 2']): Ensemble forecast, shape (S, T, 2).

  • obs (Float[jaxlib._jax.Array, 'T 2']): Observed trajectory, shape (T, 2).

  • p (float): Variogram order. Default 2.0.

  • component_weights (Float[jaxlib._jax.Array, '2 2'] | None): (2, 2) non-negative weight matrix. Defaults to ones((2, 2)) - eye(2).

  • reduce (Literal['last', 'sum'] | None): See squared_error().

  • weights (Float[jaxlib._jax.Array, 'T'] | None): See squared_error().

Returns

Per-time score of shape (T,) or a scalar, per reduce.

Return type

Float[jaxlib._jax.Array, 'T'] | Float[jaxlib._jax.Array, '']