
pastax.forcing

Forcing field representation and loading from xarray datasets or plain arrays.

Field

class pastax.forcing.Field(values, grid, stagger='center', mask=None)

Bases: Module

A single scalar forcing field on a (time, lat, lon) rectilinear grid.

A Field carries no coordinates of its own: it holds a reference to its parent Grid and reads from it the coordinates for its stagger role. As in xarray, the grid is the shared source of coordinates. The coordinate attributes (t_coords, lat_coords, lon_coords, lon_period) are read-only properties that delegate to the grid.
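The delegation pattern can be sketched with plain dataclasses. ToyGrid and ToyField are hypothetical stand-ins, not the pastax classes (which are Modules); the sketch only shows how one grid object can serve stagger-specific coordinates to every field that references it.

```python
from dataclasses import dataclass
from typing import Literal, Optional

import numpy as np

Stagger = Literal["center", "u_face", "v_face"]


@dataclass(frozen=True)
class ToyGrid:
    """Hypothetical stand-in for a Grid: the single owner of all coordinates."""
    t: np.ndarray                 # 1-D time coordinates (seconds)
    lat: dict                     # stagger -> 1-D latitudes (degrees)
    lon: dict                     # stagger -> 1-D longitudes (degrees)
    lon_period: Optional[float] = None


@dataclass(frozen=True)
class ToyField:
    """Hypothetical stand-in for a Field: values plus a reference to its grid."""
    values: np.ndarray            # shape (time, lat, lon)
    grid: ToyGrid
    stagger: Stagger = "center"

    @property
    def t_coords(self) -> np.ndarray:
        return self.grid.t

    @property
    def lat_coords(self) -> np.ndarray:
        # The stagger role selects which coordinate slot of the grid is read.
        return self.grid.lat[self.stagger]

    @property
    def lon_coords(self) -> np.ndarray:
        return self.grid.lon[self.stagger]

    @property
    def lon_period(self) -> Optional[float]:
        # Mirrors the documented behaviour: faces report no longitude period.
        return self.grid.lon_period if self.stagger == "center" else None
```

Because both fields below hold the same grid object, updating coordinates in one place updates every field that reads them.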

Parameters
  • values (Float[jaxlib._jax.Array, 'time lat lon'])

  • grid (Grid)

  • stagger (Literal['center', 'u_face', 'v_face'])

  • mask (Bool[jaxlib._jax.Array, 'lat lon'] | None)

values

Field values, shape (time, lat, lon).

Type

jaxtyping.Float[jaxlib._jax.Array, 'time lat lon']

grid

The parent Grid; the single source of coordinates for this field, indexed by stagger.

Type

pastax.grid.Grid

stagger

Position of this field on the parent grid. "center" (default) is the A-grid / tracer position; "u_face" and "v_face" mark the eastern and northern velocity faces of a NEMO-convention Arakawa C-grid. The grid serves the coordinates for the stagger position (e.g. a stagger="u_face" field reads the grid’s half-cell-shifted U-face longitudes); Field.interp itself applies the same bilinear scheme for every stagger.

Type

Literal['center', 'u_face', 'v_face']

mask

Optional 2-D boolean land mask aligned with (lat, lon); True marks a land cell, False an ocean cell. Assumed time-invariant (wetting and drying is out of scope). None (default) means no land logic, so Field.interp is plain bilinear. When a mask is present, Field.interp switches to inverse-distance partial-cell weighting that consults the mask to drop land corners.

Type

jaxtyping.Bool[jaxlib._jax.Array, 'lat lon'] | None
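A NumPy sketch of the per-cell logic a mask enables, following the behaviour documented for interp: ordinary bilinear weights in open water, inverse-distance weights over the ocean corners of a coastal cell, and 0 for a fully land-bound cell. The corner layout, the distance measured in fractional cell units, the eps guard and the helper name masked_cell_interp are assumptions; pastax.interpolation.bilinear_interp_2d() may weight corners differently.

```python
import numpy as np

# Corner layout (assumption): [[SW, SE], [NW, NE]], i.e. axis 0 is latitude.
_CORNER_Y = np.array([[0.0, 0.0], [1.0, 1.0]])
_CORNER_X = np.array([[0.0, 1.0], [0.0, 1.0]])


def masked_cell_interp(corners, land, fy, fx, eps=1e-12):
    """Interpolate inside one cell at fractional position (fy, fx) in [0, 1]^2.

    corners: 2x2 field values; land: 2x2 bools, True marks a land corner.
    """
    land = np.asarray(land, dtype=bool)
    # Zero out land corners so a stray NaN there cannot poison the sum.
    vals = np.where(land, 0.0, np.asarray(corners, dtype=float))
    if land.all():
        return 0.0  # fully land-bound cell
    if not land.any():
        # Open-ocean cell: ordinary bilinear weights.
        w = np.array([[(1 - fy) * (1 - fx), (1 - fy) * fx],
                      [fy * (1 - fx), fy * fx]])
        return float((w * vals).sum())
    # Coastal cell: inverse-distance weights over the ocean corners only.
    d = np.hypot(_CORNER_Y - fy, _CORNER_X - fx)
    w = np.where(land, 0.0, 1.0 / np.maximum(d, eps))
    return float((w * vals).sum() / w.sum())
```

With a single ocean corner the result is that corner's value; with two equidistant ocean corners it is their mean.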

values: Float[jaxlib._jax.Array, 'time lat lon']
grid: Grid
stagger: Literal['center', 'u_face', 'v_face'] = 'center'
mask: Bool[jaxlib._jax.Array, 'lat lon'] | None = None
property t_coords: Float[jaxlib._jax.Array, 'time']

1-D time coordinates in seconds (from the parent grid).

property lat_coords: Float[jaxlib._jax.Array, 'lat']

1-D latitude coordinates in degrees for this field’s stagger role.

property lon_coords: Float[jaxlib._jax.Array, 'lon']

1-D longitude coordinates in degrees for this field’s stagger role.

property lon_period: float | None

Longitude period for this field’s stagger role (None on faces).

classmethod standalone(values, t_coords, lat_coords, lon_coords, lon_period=None, stagger='center', mask=None)

Build a self-contained Field backed by a private one-field grid.

Convenience for constructing a Field outside a Dataset (e.g. in tests): the given coordinates are stored on a private Grid in the slot matching stagger, and the returned field reads them back through the usual grid-backed properties.

Parameters
  • values (Float[jaxlib._jax.Array, 'time lat lon'])

  • t_coords (Float[jaxlib._jax.Array, 'time'])

  • lat_coords (Float[jaxlib._jax.Array, 'lat'])

  • lon_coords (Float[jaxlib._jax.Array, 'lon'])

  • lon_period (float | None)

  • stagger (Literal['center', 'u_face', 'v_face'])

  • mask (Bool[jaxlib._jax.Array, 'lat lon'] | None)

Return type

Field

interp(t, lat, lon)

Trilinearly interpolate the field at a single (t, lat, lon) point.

Parameters
  • t (Float[jaxlib._jax.Array, '']): Query time in seconds.

  • lat (Float[jaxlib._jax.Array, '']): Query latitude in degrees.

  • lon (Float[jaxlib._jax.Array, '']): Query longitude in degrees.

Returns

Interpolated scalar value at the query point. Outside the grid the interpolation extrapolates linearly (clamping to grid boundaries beyond one cell). When lon_period is set, longitude wraps instead of extrapolating. When self.mask is set, coastal cells use inverse-distance partial-cell weighting and fully land-bound cells return 0 (see pastax.interpolation.bilinear_interp_2d()).

Return type

Float[jaxlib._jax.Array, '']
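The boundary behaviour above can be illustrated with a NumPy sketch of unmasked trilinear interpolation on equally spaced axes. Reading "clamping to grid boundaries beyond one cell" as "extrapolate linearly up to one cell past an edge, then hold that value" is an assumption, as are the helper names axis_weights and trilinear; this is not the pastax implementation.

```python
import itertools

import numpy as np


def axis_weights(x, coords, period=None):
    """Return (i0, i1, w) so that value = (1 - w) * v[i0] + w * v[i1].

    Assumes `coords` is equally spaced with at least two points.
    """
    x0, dx, n = coords[0], coords[1] - coords[0], len(coords)
    if period is not None:
        # Periodic axis: position within one period, in cell units.
        s = ((x - x0) % period) / dx
        i0 = int(np.floor(s))
        return i0 % n, (i0 + 1) % n, s - i0
    # Bounded axis: linear extrapolation at most one cell past either edge.
    s = min(max((x - x0) / dx, -1.0), float(n))
    i0 = min(max(int(np.floor(s)), 0), n - 2)
    return i0, i0 + 1, s - i0


def trilinear(values, t_c, lat_c, lon_c, t, lat, lon, lon_period=None):
    """Blend the 8 surrounding grid values of a (time, lat, lon) array."""
    axes = [axis_weights(t, t_c),
            axis_weights(lat, lat_c),
            axis_weights(lon, lon_c, lon_period)]
    pairs = [((i0, 1.0 - w), (i1, w)) for i0, i1, w in axes]
    total = 0.0
    for (ti, wt), (yi, wy), (xi, wx) in itertools.product(*pairs):
        total += wt * wy * wx * values[ti, yi, xi]
    return float(total)
```

Outside the grid the weight w leaves [0, 1], which is exactly what turns interpolation into linear extrapolation; on a periodic axis the upper neighbour of the last column is column 0.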

neighborhood(t, lat, lon, t_window=1, lat_window=1, lon_window=1)

Extract a window of raw grid values centred on the nearest grid point.

Parameters
  • t (Float[jaxlib._jax.Array, '']): Query time in seconds.

  • lat (Float[jaxlib._jax.Array, '']): Query latitude in degrees.

  • lon (Float[jaxlib._jax.Array, '']): Query longitude in degrees.

  • t_window (int): Half-width along the time axis (window size = 2*t_window+1).

  • lat_window (int): Half-width along the latitude axis.

  • lon_window (int): Half-width along the longitude axis.

Returns

Array of shape (2*t_window+1, 2*lat_window+1, 2*lon_window+1). Time and latitude windows are clamped to the grid boundary near the edges. The longitude window wraps modulo lon_period when that attribute is set, otherwise it is clamped like the others.

Return type

Float[jaxlib._jax.Array, 'wt wlat wlon']
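The window rules can be sketched in NumPy: snap each query coordinate to its nearest grid index, build 2*half+1 indices around it, then clamp (time, latitude, non-periodic longitude) or wrap (periodic longitude). Equally spaced coordinates and the helper names below are assumptions for illustration.

```python
import numpy as np


def nearest_index(x, coords):
    """Index of the nearest point on an equally spaced axis (may fall outside)."""
    return int(np.rint((x - coords[0]) / (coords[1] - coords[0])))


def window_indices(center, half, n, wrap):
    """2*half+1 indices around `center`, wrapped modulo n or clamped to [0, n-1]."""
    idx = np.arange(center - half, center + half + 1)
    return idx % n if wrap else np.clip(idx, 0, n - 1)


def neighborhood(values, t_c, lat_c, lon_c, t, lat, lon,
                 t_window=1, lat_window=1, lon_window=1, lon_period=None):
    """Raw (2*t_window+1, 2*lat_window+1, 2*lon_window+1) patch of `values`."""
    ti = window_indices(nearest_index(t, t_c), t_window, len(t_c), False)
    yi = window_indices(nearest_index(lat, lat_c), lat_window, len(lat_c), False)
    xi = window_indices(nearest_index(lon, lon_c), lon_window, len(lon_c),
                        lon_period is not None)
    # np.ix_ builds an open mesh, so fancy indexing returns the full 3-D block.
    return values[np.ix_(ti, yi, xi)]
```

Near an edge, clamping repeats the boundary value rather than shrinking the window, so the output shape is always fixed (which also keeps it JIT-friendly in JAX).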

Dataset

class pastax.forcing.Dataset(fields, grid=None)

Bases: Module

Collection of named Field instances sharing a common grid.

Parameters
  • fields (dict[str, Field])

  • grid (Grid | None)

fields

Mapping {field_name: Field}. For A-grid datasets every field lives at cell centres; for C-grid datasets velocity fields live on their respective faces (see Field.stagger).

Type

dict[str, pastax.forcing.Field]

grid

The shared Grid owning the coordinates of every field (centre coordinates plus, for C-grids, the staggered U/V-face coordinates) and the stagger type of the underlying ocean grid. All loaders populate it — A-grid datasets carry a stagger_type="A" grid, C-grid datasets a stagger_type="C" grid. None is accepted when constructing a Dataset directly.

Type

pastax.grid.Grid | None

fields: dict[str, Field]
grid: Grid | None = None
neighborhood(t, lat, lon, t_window=1, lat_window=1, lon_window=1)

Extract a neighbourhood patch from every field at one query point.

Equivalent to calling Field.neighborhood() on every field with the same query and window arguments. Useful for SDE terms that need local spatial gradients (e.g. Smagorinsky-style diffusion).

Parameters
  • t (Float[jaxlib._jax.Array, '']): Query time in seconds.

  • lat (Float[jaxlib._jax.Array, '']): Query latitude in degrees.

  • lon (Float[jaxlib._jax.Array, '']): Query longitude in degrees.

  • t_window (int): Half-width along the time axis (window size = 2*t_window+1).

  • lat_window (int): Half-width along the latitude axis.

  • lon_window (int): Half-width along the longitude axis.

Returns

Mapping {field_name: array} where each array has shape (2*t_window+1, 2*lat_window+1, 2*lon_window+1).

Return type

dict[str, Float[jaxlib._jax.Array, 'wt wlat wlon']]
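As an illustration of the gradient use case, here is a centred finite-difference sketch on the centre time slice of a patch with lat_window = lon_window = 1. The spherical Earth radius, the degrees-to-metres conversion, ascending latitude/longitude order and the helper name central_gradients are assumptions; this is not a pastax function.

```python
import numpy as np

EARTH_RADIUS_M = 6_371_000.0  # assumption: spherical Earth, mean radius


def central_gradients(patch, lat_deg, dlat_deg, dlon_deg):
    """Return (d/dy, d/dx) in units per metre at the centre of a patch.

    `patch` has shape (2*t_window+1, 3, 3) with ascending lat and lon.
    """
    mid = patch[patch.shape[0] // 2]  # centre time slice of the window
    dy = np.deg2rad(dlat_deg) * EARTH_RADIUS_M
    dx = np.deg2rad(dlon_deg) * EARTH_RADIUS_M * np.cos(np.deg2rad(lat_deg))
    d_dy = (mid[2, 1] - mid[0, 1]) / (2.0 * dy)  # north minus south
    d_dx = (mid[1, 2] - mid[1, 0]) / (2.0 * dx)  # east minus west
    return float(d_dy), float(d_dx)
```

Applied to each entry of the mapping returned by Dataset.neighborhood (e.g. the "u" and "v" patches), this yields the four velocity derivatives that a Smagorinsky-style deformation rate is built from.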

velocity_interp(t, lat, lon, *, scheme='default', u_name='u', v_name='v', slip_a=0.5, slip_b=0.5)

Interpolate the (U, V) velocity vector at a single point.

Returns [v_value, u_value] so the result can be used directly as the $[\mathrm{d}lat/\mathrm{d}t,\ \mathrm{d}lon/\mathrm{d}t]$ part of a solver term (after the usual metres-to-degrees conversion, if applicable).
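A sketch of that conversion, assuming velocities in m/s on a spherical Earth: latitude changes by v / R radians per second, longitude by u / (R cos lat). The radius value and the helper name metres_to_degrees_rate are assumptions, not part of pastax.

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # assumption: spherical Earth, mean radius


def metres_to_degrees_rate(vu, lat_deg):
    """Convert a [v, u] velocity in m/s to [dlat/dt, dlon/dt] in degrees/s."""
    v, u = vu
    dlat_dt = math.degrees(v / EARTH_RADIUS_M)
    # Meridians converge towards the poles, so one degree of longitude shrinks.
    dlon_dt = math.degrees(u / (EARTH_RADIUS_M * math.cos(math.radians(lat_deg))))
    return [dlat_dt, dlon_dt]
```

At 60°N the same 1 m/s produces twice the longitude rate it would at the equator, because cos 60° = 0.5.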

Parameters
  • t (Float[jaxlib._jax.Array, '']): Query time in seconds.

  • lat (Float[jaxlib._jax.Array, '']): Query latitude in degrees.

  • lon (Float[jaxlib._jax.Array, '']): Query longitude in degrees.

  • scheme (Literal['default', 'partialslip'])

    Coastal interpolation scheme.

    • "default" (default): composes per-field Field.interp() for V and U. Each field uses the scheme configured by its mask: bilinear with inverse-distance partial-cell weighting when a mask is present, plain bilinear otherwise.

    • "partialslip": A-grid only. Reads U and V together with their joint land mask (the AND of both fields’ masks) and applies a wall-slip correction whenever a full cell edge is land: U is rescaled by $(\mathrm{slip\_a} + \mathrm{slip\_b}\,w_l)$ near a latitudinal coast and V by $(\mathrm{slip\_a} + \mathrm{slip\_b}\,w_j)$ near a longitudinal coast. The default $\mathrm{slip\_a} = \mathrm{slip\_b} = 0.5$ gives a half-slip wall; $\mathrm{slip\_a} = 1,\ \mathrm{slip\_b} = 0$ recovers full free-slip. Requires both U and V to carry a mask and raises ValueError otherwise. Raises NotImplementedError on Arakawa C-grid datasets.

  • u_name (str): Name of the U-component Field in self.fields.

  • v_name (str): Name of the V-component Field in self.fields.

  • slip_a (float): Wall slip coefficient (partial-slip only).

  • slip_b (float): Wall slip gradient coefficient (partial-slip only).

Returns

[v, u] velocity vector of shape (2,).

Return type

Float[jaxlib._jax.Array, '2']
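The two documented ingredients of "partialslip", the joint mask and the rescaling factor, are sketched below in NumPy. The weights w_l and w_j are not defined on this page, so the sketch treats the weight as an opaque input assumed to lie in [0, 1]; joint_land_mask and slip_factor are hypothetical helper names.

```python
import numpy as np


def joint_land_mask(u_mask, v_mask):
    """Joint mask as documented: the elementwise AND of the U and V land masks."""
    return np.logical_and(u_mask, v_mask)


def slip_factor(w, slip_a=0.5, slip_b=0.5):
    """Rescaling factor slip_a + slip_b * w applied near a coast.

    `w` stands in for w_l or w_j (assumption: a weight in [0, 1]).
    """
    return slip_a + slip_b * np.asarray(w, dtype=float)
```

With the defaults the factor runs from 0.5 to 1 as the weight goes from 0 to 1; with slip_a = 1 and slip_b = 0 it is identically 1, so the velocity is left untouched (free slip).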

static from_arrays(fields, t, lat, lon, dtype=jnp.float32, lon_period=None, masks=None)

Build a Dataset from numpy or JAX arrays.

Parameters
  • fields (dict[str, Array]): Mapping {field_name: array of shape (time, lat, lon)}.

  • t (Array): 1-D time coordinate array. Either equally spaced numeric values (seconds since an arbitrary reference) or a NumPy datetime64 array (any unit); the latter is auto-converted to int seconds since the Unix epoch.

  • lat (Array): 1-D latitude coordinate array (degrees), equally spaced.

  • lon (Array): 1-D longitude coordinate array (degrees), equally spaced.

  • dtype (DTypeLike): JAX dtype for all arrays (default float32).

  • lon_period (float | None): If set (e.g. 360.0), all fields are constructed with periodic longitude wrapping. The grid must span exactly one period.

  • masks (dict[str, Array] | None): Optional {field_name: 2-D bool array of shape (lat, lon)} land masks; True marks a land cell. When a field appears in masks, that mask is used. Otherwise a mask is inferred from NaN locations in the values array (collapsed across the time axis). Fields with neither a user-supplied mask nor any NaN entries carry mask=None, and interp is then bit-for-bit identical to the legacy mask-less path. NaN values in the input are always replaced with 0 in the stored values so that no NaN can leak into interpolation.

Returns

Dataset with all fields on the given grid.

Return type

Dataset
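The input preprocessing described for from_arrays can be sketched in NumPy: datetime64 to epoch seconds, NaN-based mask inference with NaN replaced by 0, and the "exactly one period" check for lon_period. Treating a cell as land when it is NaN at any time step, and the helper names below, are assumptions about details the text leaves open.

```python
import numpy as np


def to_epoch_seconds(t):
    """datetime64 (any unit) -> int64 seconds since 1970-01-01; numbers pass through."""
    t = np.asarray(t)
    if np.issubdtype(t.dtype, np.datetime64):
        return t.astype("datetime64[s]").astype(np.int64)
    return t


def infer_mask_and_clean(values, mask=None):
    """Return (clean_values, mask) for a (time, lat, lon) array.

    Assumption: a cell counts as land if it is NaN at ANY time step.
    """
    values = np.asarray(values, dtype=float)
    nan = np.isnan(values)
    if mask is None and nan.any():
        mask = nan.any(axis=0)  # collapse the time axis -> (lat, lon)
    clean = np.where(nan, 0.0, values)  # no NaN may reach interpolation
    return clean, mask


def spans_one_period(lon, period):
    """True if n equally spaced points with step dx satisfy n * dx == period."""
    return bool(np.isclose(len(lon) * (lon[1] - lon[0]), period))
```

A field without any NaN and without a user mask comes back with mask None, matching the mask-less path described above.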

static from_xarray(ds, fields, coordinates, dtype=jnp.float32, lon_period=None, masks=None)

Load a Dataset from an xarray Dataset (zarr or netCDF backed).

Parameters
  • ds (xr.Dataset): Source xarray Dataset.

  • fields (dict[str, str]): Mapping {internal_name: xarray_variable_name}.

  • coordinates (dict[str, str]): Mapping with keys "time", "lat", "lon" → xarray coord names.

  • dtype (DTypeLike): JAX dtype for all arrays (default float32).

  • lon_period (float | None): If set (e.g. 360.0), all fields are constructed with periodic longitude wrapping. The grid must span exactly one period.

  • masks (dict[str, Array] | None): Optional land masks keyed by internal field name; see from_arrays() for semantics. If omitted, masks are inferred from NaN locations, which matches the CMEMS / CF _FillValue convention.

Returns

Dataset with all fields loaded into host memory as JAX arrays.

Return type

Dataset

static from_arrays_cgrid(t, center_lat, center_lon, vectors, tracers=None, *, u_lat=None, u_lon=None, v_lat=None, v_lon=None, dtype=jnp.float32, lon_period=None, masks=None)

Build a Dataset on a NEMO-convention Arakawa C-grid.

The centre grid (center_lat, center_lon) carries any tracer fields. Each vector field has its U component on the east faces of the centre cells (one fewer longitude column) and its V component on the north faces (one fewer latitude row). Several vector fields can share the same C-grid — e.g. surface current and 10-m wind, or geostrophic / Ekman / Stokes velocity components — by registering additional entries in vectors. When the staggered coordinate arrays are omitted they are auto-derived from the centre grid as half-cell shifts (see Grid.u_face_coords() / Grid.v_face_coords()) and shared by every registered vector.
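The default face coordinates can be sketched as midpoints of adjacent centres, which on an equally spaced grid is the same as a half-cell shift: U faces keep the centre latitudes and take nlon - 1 shifted longitudes, V faces take nlat - 1 shifted latitudes and keep the centre longitudes. The helper name face_coords and its overrides argument (mirroring u_lat / u_lon / v_lat / v_lon) are hypothetical.

```python
import numpy as np


def face_coords(center_lat, center_lon, overrides=None):
    """Default NEMO-style face coordinates as half-cell shifts of the centres.

    U faces: centre latitudes, longitudes shifted east (length nlon - 1).
    V faces: latitudes shifted north (length nlat - 1), centre longitudes.
    Any key in `overrides` replaces the derived array.
    """
    center_lat = np.asarray(center_lat, dtype=float)
    center_lon = np.asarray(center_lon, dtype=float)
    coords = {
        "u_lat": center_lat,
        "u_lon": 0.5 * (center_lon[:-1] + center_lon[1:]),
        "v_lat": 0.5 * (center_lat[:-1] + center_lat[1:]),
        "v_lon": center_lon,
    }
    coords.update(overrides or {})
    return coords
```

Every registered vector would share the one resulting set of face coordinates, which is why the overrides are not per vector.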

Parameters
  • t (Array): 1-D time coordinates (seconds or NumPy datetime64).

  • center_lat (Array): 1-D centre latitudes (degrees), equally spaced.

  • center_lon (Array): 1-D centre longitudes (degrees), equally spaced.

  • vectors (dict[str, dict[Literal['u', 'v'], tuple[str, Array]]]): Mapping {group_name: {"u": (field_name, u_array), "v": (field_name, v_array)}}. The outer key is a free-form label for the vector pair (e.g. "current", "wind", "geostrophic") and is used only in error messages. The inner (field_name, array) tuples give the names under which each component is registered in Dataset.fields (and how velocity_interp() finds them via u_name / v_name) and the corresponding values. U arrays have shape (time, nlat, nlon - 1); V arrays have shape (time, nlat - 1, nlon). At least one vector must be supplied; field names must be unique across all vectors and tracers.

  • tracers (dict[str, Array] | None): Optional mapping {name: array of shape (time, nlat, nlon)} for additional fields at cell centres.

  • u_lat (Array | None): Override for U latitudes (defaults to center_lat). Shared by every registered U field.

  • u_lon (Array | None): Override for U longitudes (defaults to centre lons shifted east by half a cell, length nlon - 1). Shared by every registered U field.

  • v_lat (Array | None): Override for V latitudes (defaults to centre lats shifted north by half a cell, length nlat - 1). Shared by every registered V field.

  • v_lon (Array | None): Override for V longitudes (defaults to center_lon). Shared by every registered V field.

  • dtype (DTypeLike): JAX dtype for all arrays (default float32).

  • lon_period (float | None): If set (e.g. 360.0), the centre grid is treated as periodic in longitude. Tracer fields receive lon_period; U/V faces do not (their coordinate arrays no longer span a full period, so periodic wrapping would be ill-defined at first order).

  • masks (dict[str, Array] | None): Optional land masks keyed by the field names declared in vectors and tracers. Each mask is a 2-D bool array; the expected shape per field is (nlat, nlon - 1) for a U-face field, (nlat - 1, nlon) for a V-face field, and (nlat, nlon) for a tracer. When a field is absent from masks, a mask is inferred from NaN locations in that field’s values array. NaN values are always replaced with 0 in the stored values.

Returns

Dataset with one Field(stagger="u_face") and one Field(stagger="v_face") per registered vector (plus any tracers) and a C-grid Grid metadata object.

Return type

Dataset
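The shape and naming rules for vectors, tracers and masks can be collected into one validation pass. check_cgrid_inputs is a hypothetical helper written from the rules stated above, not a pastax function; it returns the stagger each registered field name would receive.

```python
import numpy as np


def check_cgrid_inputs(nlat, nlon, vectors, tracers=None, masks=None):
    """Validate shapes and name uniqueness; return {field_name: stagger}."""
    expected = {"u_face": (nlat, nlon - 1), "v_face": (nlat - 1, nlon),
                "center": (nlat, nlon)}
    if not vectors:
        raise ValueError("at least one vector must be supplied")
    entries = []  # (field_name, array, stagger, group label for messages)
    for group, comps in vectors.items():
        for comp, stagger in (("u", "u_face"), ("v", "v_face")):
            name, arr = comps[comp]
            entries.append((name, arr, stagger, group))
    for name, arr in (tracers or {}).items():
        entries.append((name, arr, "center", "tracers"))
    staggers = {}
    for name, arr, stagger, group in entries:
        if name in staggers:
            raise ValueError(f"duplicate field name {name!r} (in {group!r})")
        if np.shape(arr)[1:] != expected[stagger]:
            raise ValueError(f"{group}/{name}: expected (time, *{expected[stagger]}), "
                             f"got {np.shape(arr)}")
        if masks and name in masks and np.shape(masks[name]) != expected[stagger]:
            raise ValueError(f"mask for {name!r} must have shape {expected[stagger]}")
        staggers[name] = stagger
    return staggers
```

A second vector group (e.g. "wind") simply adds two more entries; it only fails if it reuses a field name or mismatches the face shapes.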

static from_xarray_cgrid(ds, *, vectors, coordinates, tracers=None, staggered_coordinates=None, dtype=jnp.float32, lon_period=None, masks=None)

Load a C-grid Dataset from an xarray Dataset.

Centre coordinates (used for time and tracer fields) come from coordinates; staggered U/V coordinates are auto-derived from the centre grid as half-cell shifts unless overridden via staggered_coordinates. Multiple vector fields living on the same C-grid (e.g. surface current and 10-m wind) are declared as separate entries in vectors.

Parameters
  • ds (xr.Dataset): Source xarray Dataset.

  • vectors (dict[str, dict[Literal['u', 'v'], tuple[str, str]]]): Mapping {group_name: {"u": (field_name, xarray_var_name), "v": (field_name, xarray_var_name)}}. The outer key is a free-form label for the vector pair (e.g. "current", "wind"). Each inner (field_name, xarray_var_name) tuple says which xarray variable holds the values and under which name to register the resulting Field in Dataset.fields. U variables have shape (time, nlat, nlon - 1); V variables have shape (time, nlat - 1, nlon).

  • coordinates (dict[str, str]): Mapping with keys "time", "lat", "lon" → xarray coord names for the centre grid.

  • tracers (dict[str, str] | None): Optional {internal_name: xarray_variable_name} for extra centre-grid fields.

  • staggered_coordinates (dict[str, str] | None): Optional override mapping with any subset of keys "u_lat", "u_lon", "v_lat", "v_lon" → xarray coord names. Unspecified keys are auto-derived. The overrides are shared by every registered vector.

  • dtype (DTypeLike): JAX dtype for all arrays (default float32).

  • lon_period (float | None): Forwarded to from_arrays_cgrid().

  • masks (dict[str, Array] | None): Forwarded to from_arrays_cgrid(). Keys must match the field names declared in vectors and tracers.

Returns

Dataset with C-grid stagger and Grid metadata.

Return type

Dataset