Forcing field representation and loading from xarray datasets or plain arrays.
Field
- class pastax.forcing.Field(values, grid, stagger='center', mask=None)
Bases: Module
A single scalar forcing field on a (time, lat, lon) rectilinear grid.
A Field carries no coordinates of its own: it holds a reference to its parent Grid and reads from it the coordinates appropriate to its stagger role (xarray-like: the grid is the shared source of coordinates). The coordinate attributes (t_coords, lat_coords, lon_coords, lon_period) are available as read-only properties that delegate to the grid.
- Parameters
values (Float[jaxlib._jax.Array, 'time lat lon'])
grid (Grid)
stagger (Literal['center', 'u_face', 'v_face'])
mask (Bool[jaxlib._jax.Array, 'lat lon'] | None)
- values
Field values, shape (time, lat, lon).
- Type
jaxtyping.Float[jaxlib._jax.Array, 'time lat lon']
- grid
The parent Grid: the single source of coordinates for this field, indexed by stagger.
- Type
pastax.grid.Grid
- stagger
Position of this field on the parent grid.
"center" (default) is the A-grid / tracer position; "u_face" and "v_face" mark the eastern and northern velocity faces of a NEMO-convention Arakawa C-grid. The grid serves the coordinates for the stagger position (so a stagger="u_face" field reads the grid's half-cell-shifted U-face longitudes); Field.interp itself uses the same bilinear scheme regardless of stagger.
- Type
Literal['center', 'u_face', 'v_face']
- mask
Optional 2-D boolean land mask aligned with (lat, lon); True marks a land cell, False marks ocean. The mask is assumed time-invariant (wetting and drying is out of scope). None (default) means no land logic: Field.interp is plain bilinear. When a mask is present, Field.interp switches to inverse-distance partial-cell weighting, which consults the mask to drop land corners.
- Type
jaxtyping.Bool[jaxlib._jax.Array, 'lat lon'] | None
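A rough picture of the masked path: the NumPy sketch below applies inverse-distance weighting to the four corners of a single grid cell, gives land corners zero weight, and returns 0 when every corner is land. It is not pastax's implementation; the helper name `idw_cell`, the corner ordering, plain 1/d weights measured in fractional-index space, and snapping to a wet corner that the query sits on are all assumptions made for illustration.

```python
import numpy as np

# Corner offsets (row = lat, col = lon) in fractional-index space, ordered
# (0, 0), (0, 1), (1, 0), (1, 1). This ordering is an assumption of the sketch.
_CORNERS = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])


def idw_cell(corner_values, corner_land, fy, fx, eps=1e-12):
    """Hypothetical inverse-distance weighting over the wet corners of one cell.

    corner_values: 4 values at the cell corners (ordered as _CORNERS).
    corner_land:   4 booleans, True where the corner is land.
    fy, fx:        fractional position of the query inside the cell, in [0, 1].
    """
    values = np.asarray(corner_values, dtype=float)
    wet = ~np.asarray(corner_land, dtype=bool)
    if not wet.any():
        # Every corner is land: return 0, as documented for land-bound cells.
        return 0.0
    dist = np.hypot(_CORNERS[:, 0] - fy, _CORNERS[:, 1] - fx)
    on_corner = wet & (dist < eps)
    if on_corner.any():
        # The query sits exactly on a wet corner: use it and avoid 1 / 0.
        return float(values[np.argmax(on_corner)])
    # Land corners get zero weight; wet corners are weighted by 1 / distance.
    weights = np.where(wet, 1.0 / np.maximum(dist, eps), 0.0)
    return float(np.sum(weights * values) / np.sum(weights))
```

Because land corners are dropped and the remaining weights are renormalised, the result stays a weighted average of wet values, so the zeros stored on land cannot pull coastal values toward 0.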
- values: Float[jaxlib._jax.Array, 'time lat lon']
- grid: Grid
- stagger: Literal['center', 'u_face', 'v_face'] = 'center'
- mask: Bool[jaxlib._jax.Array, 'lat lon'] | None = None
- property t_coords: Float[jaxlib._jax.Array, 'time']
1-D time coordinates in seconds (from the parent grid).
- property lat_coords: Float[jaxlib._jax.Array, 'lat']
1-D latitude coordinates in degrees for this field’s stagger role.
- property lon_coords: Float[jaxlib._jax.Array, 'lon']
1-D longitude coordinates in degrees for this field’s stagger role.
- property lon_period: float | None
Longitude period for this field's stagger role (None on faces).
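The coordinate delegation described for Field can be pictured with a plain-Python sketch. `ToyGrid` and `ToyField` are hypothetical stand-ins (the real classes are Module subclasses holding JAX arrays); only the idea that a field stores its stagger role and asks the grid for the matching coordinates, with faces reporting no longitude period, comes from the text above.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class ToyGrid:
    """Hypothetical grid holding one coordinate set per stagger role."""
    t: Sequence[float]
    center_lat: Sequence[float]
    center_lon: Sequence[float]
    u_lat: Optional[Sequence[float]] = None
    u_lon: Optional[Sequence[float]] = None
    v_lat: Optional[Sequence[float]] = None
    v_lon: Optional[Sequence[float]] = None
    lon_period: Optional[float] = None

    def coords(self, stagger):
        """Return (lat, lon, lon_period) for a stagger role."""
        if stagger == "center":
            return self.center_lat, self.center_lon, self.lon_period
        if stagger == "u_face":
            return self.u_lat, self.u_lon, None  # faces never wrap
        if stagger == "v_face":
            return self.v_lat, self.v_lon, None
        raise ValueError(f"unknown stagger {stagger!r}")


@dataclass(frozen=True)
class ToyField:
    """Hypothetical field: owns its values, borrows coordinates from its grid."""
    values: list
    grid: ToyGrid
    stagger: str = "center"

    @property
    def t_coords(self):
        return self.grid.t

    @property
    def lat_coords(self):
        return self.grid.coords(self.stagger)[0]

    @property
    def lon_coords(self):
        return self.grid.coords(self.stagger)[1]

    @property
    def lon_period(self):
        return self.grid.coords(self.stagger)[2]
```

Because coordinates live only on the grid, two fields that share a grid cannot disagree about their coordinates; a U-face field differs from a tracer only in which slot it reads.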
- classmethod standalone(values, t_coords, lat_coords, lon_coords, lon_period=None, stagger='center', mask=None)
Build a self-contained Field backed by a private one-field grid.
Convenience for constructing a Field outside a Dataset (e.g. in tests): the given coordinates are stored on a private Grid in the slot matching stagger, and the returned field reads them back through the usual grid-backed properties.
- Parameters
values (Float[jaxlib._jax.Array, 'time lat lon'])
t_coords (Float[jaxlib._jax.Array, 'time'])
lat_coords (Float[jaxlib._jax.Array, 'lat'])
lon_coords (Float[jaxlib._jax.Array, 'lon'])
lon_period (float | None)
stagger (Literal['center', 'u_face', 'v_face'])
mask (Bool[jaxlib._jax.Array, 'lat lon'] | None)
- Return type
Field
- interp(t, lat, lon)
Trilinearly interpolate the field at a single (t, lat, lon) point.
- Parameters
t (Float[jaxlib._jax.Array, '']) – Query time in seconds.
lat (Float[jaxlib._jax.Array, '']) – Query latitude in degrees.
lon (Float[jaxlib._jax.Array, '']) – Query longitude in degrees.
- Returns
Interpolated scalar value at the query point. Outside the grid the interpolation extrapolates linearly, clamping to the grid boundaries beyond one cell. When lon_period is set, longitude wraps instead of extrapolating. When self.mask is set, coastal cells use inverse-distance partial-cell weighting and fully land-bound cells return 0 (see pastax.interpolation.bilinear_interp_2d()).
- Return type
Float[jaxlib._jax.Array, '']
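To make the documented edge behaviour concrete, here is a NumPy sketch of trilinear interpolation on an equally spaced, mask-less grid (linear in time, bilinear in latitude and longitude). It is an illustration, not pastax's code: the function names, the `(start, spacing)` description of each axis, and the reading of "clamping beyond one cell" as "extrapolate linearly up to one cell past the edge, then hold" are assumptions. For periodic longitude it also assumes `nlon * dlon == lon_period`.

```python
import numpy as np


def _axis(x, x0, dx, n, period=None):
    """Return (i0, i1, w) such that the 1-D interpolant is (1 - w) * f[i0] + w * f[i1]."""
    s = (x - x0) / dx  # fractional grid index
    if period is not None:
        # Periodic axis: assumes n * dx == period, so index n wraps back to 0.
        fl = np.floor(s)
        i0 = int(fl) % n
        return i0, (i0 + 1) % n, float(s - fl)
    # Non-periodic: points outside the grid reuse the edge cell (linear
    # extrapolation), but the weight is capped at one cell past the edge.
    i0 = int(np.clip(np.floor(s), 0, n - 2))
    w = float(np.clip(s - i0, -1.0, 2.0))
    return i0, i0 + 1, w


def trilinear(values, t_axis, lat_axis, lon_axis, t, lat, lon, lon_period=None):
    """Interpolate values (time, lat, lon) at one point; each *_axis is (start, spacing)."""
    nt, nlat, nlon = values.shape
    ti = _axis(t, *t_axis, nt)
    yi = _axis(lat, *lat_axis, nlat)
    xi = _axis(lon, *lon_axis, nlon, lon_period)
    out = 0.0
    # Sum the 8 corner values weighted by the product of the 1-D weights.
    for it, wt in ((ti[0], 1.0 - ti[2]), (ti[1], ti[2])):
        for iy, wy in ((yi[0], 1.0 - yi[2]), (yi[1], yi[2])):
            for ix, wx in ((xi[0], 1.0 - xi[2]), (xi[1], xi[2])):
                out += wt * wy * wx * values[it, iy, ix]
    return float(out)
```

With `lon_period` set, a query at 355° on a grid with columns at 0, 10, ..., 350° blends the last and first columns instead of extrapolating.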
- neighborhood(t, lat, lon, t_window=1, lat_window=1, lon_window=1)
Extract a window of raw grid values centred on the nearest grid point.
- Parameters
t (Float[jaxlib._jax.Array, '']) – Query time in seconds.
lat (Float[jaxlib._jax.Array, '']) – Query latitude in degrees.
lon (Float[jaxlib._jax.Array, '']) – Query longitude in degrees.
t_window (int) – Half-width along the time axis (window size = 2*t_window+1).
lat_window (int) – Half-width along the latitude axis.
lon_window (int) – Half-width along the longitude axis.
- Returns
Array of shape (2*t_window+1, 2*lat_window+1, 2*lon_window+1). Near the edges, the time and latitude windows are clamped to the grid boundary. The longitude window wraps modulo lon_period when that attribute is set; otherwise it is clamped like the others.
- Return type
Float[jaxlib._jax.Array, 'wt wlat wlon']
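The window logic can be sketched in NumPy with index arithmetic and `np.ix_`. This is an illustration under assumptions: the helper names are hypothetical, and "clamped" is read here as clipping out-of-range indices to the edge (so edge values repeat), whereas an implementation could instead shift the window inward.

```python
import numpy as np


def nearest_index(x, x0, dx, n, periodic=False):
    """Index of the grid point nearest to x on an equally spaced axis."""
    i = int(np.rint((x - x0) / dx))
    return i % n if periodic else int(np.clip(i, 0, n - 1))


def neighborhood(values, center, half_widths, lon_periodic=False):
    """Window of raw values around center = (it, ilat, ilon).

    Out-of-range indices are clipped to the edge (edge values repeat),
    except along longitude when lon_periodic is True, where they wrap.
    """
    nt, nlat, nlon = values.shape
    (it, iy, ix), (wt, wy, wx) = center, half_widths
    t_idx = np.clip(np.arange(it - wt, it + wt + 1), 0, nt - 1)
    y_idx = np.clip(np.arange(iy - wy, iy + wy + 1), 0, nlat - 1)
    x_rng = np.arange(ix - wx, ix + wx + 1)
    x_idx = x_rng % nlon if lon_periodic else np.clip(x_rng, 0, nlon - 1)
    # np.ix_ builds an open mesh, so the result has shape (2wt+1, 2wy+1, 2wx+1).
    return values[np.ix_(t_idx, y_idx, x_idx)]
```

In use, the centre indices would come from `nearest_index` applied to the query time, latitude and longitude.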
Dataset
- class pastax.forcing.Dataset(fields, grid=None)
Bases: Module
Collection of named Field instances sharing a common grid.
- fields
Mapping {field_name: Field}. For A-grid datasets every field lives at cell centres; for C-grid datasets velocity fields live on their respective faces (see Field.stagger).
- Type
dict[str, pastax.forcing.Field]
- grid
The shared Grid owning the coordinates of every field (centre coordinates plus, for C-grids, the staggered U/V-face coordinates) and the stagger type of the underlying ocean grid. All loaders populate it: A-grid datasets carry a stagger_type="A" grid, C-grid datasets a stagger_type="C" grid. None is accepted when constructing a Dataset directly.
- Type
pastax.grid.Grid | None
- fields: dict[str, Field]
- grid: Grid | None = None
- neighborhood(t, lat, lon, t_window=1, lat_window=1, lon_window=1)
Extract a neighbourhood patch from every field at one query point.
Equivalent to calling Field.neighborhood() on every field with the same query and window arguments. Useful for SDE terms that need local spatial gradients (e.g. Smagorinsky-style diffusion).
- Parameters
t (Float[jaxlib._jax.Array, '']) – Query time in seconds.
lat (Float[jaxlib._jax.Array, '']) – Query latitude in degrees.
lon (Float[jaxlib._jax.Array, '']) – Query longitude in degrees.
t_window (int) – Half-width along the time axis (window size = 2*t_window+1).
lat_window (int) – Half-width along the latitude axis.
lon_window (int) – Half-width along the longitude axis.
- Returns
Mapping {field_name: array} where each array has shape (2*t_window+1, 2*lat_window+1, 2*lon_window+1).
- Return type
dict[str, Float[jaxlib._jax.Array, 'wt wlat wlon']]
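As an illustration of the Smagorinsky use case, the sketch below takes 3x3 patches like those returned for "u" and "v" (one time step, rows = latitude, columns = longitude) and evaluates one common form of the Smagorinsky eddy viscosity, `C * dx * dy * sqrt(D_T**2 + D_S**2)`, with tension `D_T = du/dx - dv/dy` and shear `D_S = dv/dx + du/dy` from centred differences. The function name, the constant `c_smag=0.1`, and the assumption that the grid spacings are already known in metres are illustrative choices, not pastax behaviour.

```python
import numpy as np


def smagorinsky_viscosity(patches, dx, dy, c_smag=0.1):
    """Smagorinsky-style viscosity at the centre of 3x3 (lat, lon) patches.

    patches: {"u": array (1, 3, 3), "v": array (1, 3, 3)}, rows = lat, cols = lon.
    dx, dy:  grid spacing in metres along lon and lat (assumed known here).
    """
    u = patches["u"][0]  # drop the singleton time axis
    v = patches["v"][0]
    # Centred differences about the middle cell.
    du_dx = (u[1, 2] - u[1, 0]) / (2.0 * dx)
    du_dy = (u[2, 1] - u[0, 1]) / (2.0 * dy)
    dv_dx = (v[1, 2] - v[1, 0]) / (2.0 * dx)
    dv_dy = (v[2, 1] - v[0, 1]) / (2.0 * dy)
    tension = du_dx - dv_dy
    shear = dv_dx + du_dy
    return c_smag * dx * dy * np.sqrt(tension**2 + shear**2)
```

With `lat_window=1, lon_window=1, t_window=0`, the mapping returned by the neighbourhood call has exactly the `(1, 3, 3)` patches this function expects.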
- velocity_interp(t, lat, lon, *, scheme='default', u_name='u', v_name='v', slip_a=0.5, slip_b=0.5)
Interpolate the (U, V) velocity vector at a single point.
Returns [v_value, u_value], i.e. in (lat, lon) order, so the result can be used directly in a solver term (after the usual metres-to-degrees conversion, if applicable).
- Parameters
t (Float[jaxlib._jax.Array, '']) – Query time in seconds.
lat (Float[jaxlib._jax.Array, '']) – Query latitude in degrees.
lon (Float[jaxlib._jax.Array, '']) – Query longitude in degrees.
scheme (Literal['default', 'partialslip']) – Coastal interpolation scheme.
"default" (default): composes per-field Field.interp() for V and U. Each field uses its own scheme as configured by its mask: bilinear with inverse-distance partial-cell weighting when a mask is present, plain bilinear otherwise.
"partialslip": A-grid only. Reads U and V together with their joint land mask (the AND of both fields' masks) and applies a wall-slip correction whenever a full cell edge is land: U is rescaled by a slip factor near a latitudinal coast, and V by a slip factor near a longitudinal coast. The defaults (slip_a=0.5, slip_b=0.5) give a half-slip wall; other coefficient values recover full free-slip. Requires both U and V to carry a mask and raises ValueError otherwise. Raises NotImplementedError on Arakawa C-grid datasets.
u_name (str) – Name of the U-component Field in self.fields.
v_name (str) – Name of the V-component Field in self.fields.
slip_a (float) – Wall slip coefficient (partial-slip only).
slip_b (float) – Wall slip gradient coefficient (partial-slip only).
- Returns
[v, u] velocity vector of shape (2,).
- Return type
Float[jaxlib._jax.Array, '2']
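Because the result is already in (lat, lon) order, a drift term typically only needs to rescale each component. A minimal sketch of that metres-to-degrees conversion on a spherical Earth follows; the radius value and the helper name are assumptions, and pastax may perform this step differently (or skip it if the forcing is already in degrees per second).

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # assumed mean Earth radius; pastax may use another value


def vu_to_degrees_per_second(vu, lat_deg):
    """Convert a [v, u] velocity in m/s to [dlat/dt, dlon/dt] in degrees/s.

    Spherical-Earth approximation: one degree of latitude spans R*pi/180 metres,
    one degree of longitude spans R*cos(lat)*pi/180 metres.
    """
    v, u = vu
    metres_per_degree = EARTH_RADIUS_M * math.pi / 180.0
    dlat_dt = v / metres_per_degree
    dlon_dt = u / (metres_per_degree * math.cos(math.radians(lat_deg)))
    return [dlat_dt, dlon_dt]
```

The `cos(lat)` factor makes the same eastward speed cover more degrees of longitude at high latitudes; it diverges at the poles, which a real solver would need to guard against.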
- static from_arrays(fields, t, lat, lon, dtype=jnp.float32, lon_period=None, masks=None)
Build a Dataset from numpy or JAX arrays.
- Parameters
fields (dict[str, Array]) – Mapping {field_name: array of shape (time, lat, lon)}.
t (Array) – 1-D time coordinate array. Either equally spaced numeric values (seconds since an arbitrary reference) or a NumPy datetime64 array (any unit); the latter is auto-converted to integer seconds since the Unix epoch.
lat (Array) – 1-D latitude coordinate array (degrees), equally spaced.
lon (Array) – 1-D longitude coordinate array (degrees), equally spaced.
dtype (DTypeLike) – JAX dtype for all arrays (default float32).
lon_period (float | None) – If set (e.g. 360.0), all fields are constructed with periodic longitude wrapping. The grid must span exactly one period.
masks (dict[str, Array] | None) – Optional {field_name: 2-D bool array of shape (lat, lon)} land masks. True marks a land cell. When a field appears in masks, that mask is used; otherwise a mask is inferred from NaN locations in the values array (collapsed across the time axis). Fields with neither a user-supplied mask nor any NaN entries carry mask=None, and interp behaviour is then bit-for-bit identical to the legacy mask-less path. NaN values in the input are always replaced with 0 in the stored values, so no NaN can leak into interpolation.
- Returns
Dataset with all fields on the given grid.
- Return type
Dataset
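The two preprocessing rules above (NaN becomes a land mask and then 0; datetime64 becomes integer epoch seconds) can be sketched in NumPy. The helper names are hypothetical, and collapsing NaN across time with "any" is an assumption; for a time-invariant mask "any" and "all" agree.

```python
import numpy as np


def clean_values(values, mask=None):
    """Return (values with NaN -> 0, mask), in the spirit of from_arrays."""
    values = np.asarray(values, dtype=np.float32)
    nan = np.isnan(values)
    if mask is None and nan.any():
        # Collapse across time: a cell is land if it is NaN at any time step
        # (assumption; for a time-invariant mask "any" and "all" agree).
        mask = nan.any(axis=0)
    # No mask supplied and no NaN present: mask stays None (mask-less path).
    return np.where(nan, 0.0, values).astype(np.float32), mask


def time_to_seconds(t):
    """datetime64 (any unit) -> int seconds since the Unix epoch; numbers pass through."""
    t = np.asarray(t)
    if np.issubdtype(t.dtype, np.datetime64):
        return t.astype("datetime64[s]").astype(np.int64)
    return t
```

Converting through `datetime64[s]` truncates any sub-second part, which matches the documented "int seconds".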
- static from_xarray(ds, fields, coordinates, dtype=jnp.float32, lon_period=None, masks=None)
Load a Dataset from an xarray Dataset (zarr or netCDF backed).
- Parameters
ds (xr.Dataset) – Source xarray Dataset.
fields (dict[str, str]) – Mapping {internal_name: xarray_variable_name}.
coordinates (dict[str, str]) – Mapping with keys "time", "lat", "lon" → xarray coord names.
dtype (DTypeLike) – JAX dtype for all arrays (default float32).
lon_period (float | None) – If set (e.g. 360.0), all fields are constructed with periodic longitude wrapping. The grid must span exactly one period.
masks (dict[str, Array] | None) – Optional land masks keyed by internal field name; see from_arrays() for semantics. If omitted, masks are inferred from NaN, which matches the CMEMS / CF _FillValue convention.
- Returns
Dataset with all fields loaded into host memory as JAX arrays.
- Return type
Dataset
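With xarray's default decoding (`mask_and_scale=True`), values equal to a variable's `_FillValue` come back as NaN, which is what lets NaN-based mask inference work on CF-compliant files. The NumPy sketch below spells out that CF decoding step for packed integer data; the helper name and the attribute values in the usage are made up for the example.

```python
import numpy as np


def cf_decode(packed, fill_value, scale_factor=1.0, add_offset=0.0):
    """Apply the CF packing convention: fill -> NaN, otherwise scale and offset."""
    packed = np.asarray(packed)
    unpacked = packed.astype(np.float64) * scale_factor + add_offset
    # The fill test is done on the packed values, before scaling.
    return np.where(packed == fill_value, np.nan, unpacked)
```

After decoding, `np.isnan(decoded).any(axis=0)` yields a (lat, lon) land mask of the kind the loaders infer.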
- static from_arrays_cgrid(t, center_lat, center_lon, vectors, tracers=None, *, u_lat=None, u_lon=None, v_lat=None, v_lon=None, dtype=jnp.float32, lon_period=None, masks=None)
Build a Dataset on a NEMO-convention Arakawa C-grid.
The centre grid (center_lat, center_lon) carries any tracer fields. Each vector field has its U component on the east faces of the centre cells (one fewer longitude column) and its V component on the north faces (one fewer latitude row). Several vector fields can share the same C-grid (e.g. surface current and 10-m wind, or geostrophic / Ekman / Stokes velocity components) by registering additional entries in vectors. When the staggered coordinate arrays are omitted, they are auto-derived from the centre grid as half-cell shifts (see Grid.u_face_coords() / Grid.v_face_coords()) and shared by every registered vector.
- Parameters
t (Array) – 1-D time coordinates (seconds or NumPy datetime64).
center_lat (Array) – 1-D centre latitudes (degrees), equally spaced.
center_lon (Array) – 1-D centre longitudes (degrees), equally spaced.
vectors (dict[str, dict[Literal['u', 'v'], tuple[str, Array]]]) – Mapping {group_name: {"u": (field_name, u_array), "v": (field_name, v_array)}}. The outer key is a free-form label for the vector pair (e.g. "current", "wind", "geostrophic") and is used only in error messages. The inner (field_name, array) tuples give the names under which each component is registered in Dataset.fields (which is how velocity_interp() finds them via u_name / v_name) and the corresponding values. U arrays have shape (time, nlat, nlon - 1); V arrays have shape (time, nlat - 1, nlon). At least one vector must be supplied; field names must be unique across all vectors and tracers.
tracers (dict[str, Array] | None) – Optional mapping {name: array of shape (time, nlat, nlon)} for additional fields at cell centres.
u_lat (Array | None) – Override for U latitudes (defaults to center_lat). Shared by every registered U field.
u_lon (Array | None) – Override for U longitudes (defaults to centre lons shifted east by half a cell, length nlon - 1). Shared by every registered U field.
v_lat (Array | None) – Override for V latitudes (defaults to centre lats shifted north by half a cell, length nlat - 1). Shared by every registered V field.
v_lon (Array | None) – Override for V longitudes (defaults to center_lon). Shared by every registered V field.
dtype (DTypeLike) – JAX dtype for all arrays (default float32).
lon_period (float | None) – If set (e.g. 360.0), the centre grid is treated as periodic in longitude. Tracer fields receive lon_period; U/V faces do not (their coordinate arrays no longer span a full period, so periodic wrapping would be ill-defined at first order).
masks (dict[str, Array] | None) – Optional land masks keyed by the field names declared in vectors and tracers.
Each mask is a 2-D bool array; the expected shape per field is (nlat, nlon - 1) for a U-face field, (nlat - 1, nlon) for a V-face field, and (nlat, nlon) for a tracer. When a field is absent from masks, a mask is inferred from NaN locations in that field's values array. NaN values are always replaced with 0 in the stored values.
- Returns
Dataset with one Field(stagger="u_face") and one Field(stagger="v_face") per registered vector (plus any tracers) and a C-grid Grid metadata object.
- Return type
Dataset
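The shape rules and default face coordinates above can be checked with a short NumPy sketch. `face_coords` and `check_vectors` are hypothetical helpers, not pastax functions (the library's own derivation lives in Grid.u_face_coords() / Grid.v_face_coords()); the half-cell shift is taken literally from the parameter descriptions as `center[:-1] + spacing / 2`.

```python
import numpy as np


def face_coords(center_lat, center_lon):
    """Half-cell-shifted face coordinates for a NEMO-style C-grid (sketch)."""
    center_lat = np.asarray(center_lat, dtype=float)
    center_lon = np.asarray(center_lon, dtype=float)
    dlat = center_lat[1] - center_lat[0]
    dlon = center_lon[1] - center_lon[0]
    return {
        "u_lat": center_lat,                    # U faces sit at centre latitudes
        "u_lon": center_lon[:-1] + 0.5 * dlon,  # east faces: nlon - 1 columns
        "v_lat": center_lat[:-1] + 0.5 * dlat,  # north faces: nlat - 1 rows
        "v_lon": center_lon,
    }


def check_vectors(vectors, nt, nlat, nlon, tracers=None):
    """Validate a `vectors` mapping against the documented C-grid shapes."""
    if not vectors:
        raise ValueError("at least one vector must be supplied")
    names = list((tracers or {}).keys())
    for group, comps in vectors.items():
        for comp, expected in (("u", (nt, nlat, nlon - 1)), ("v", (nt, nlat - 1, nlon))):
            name, arr = comps[comp]
            if np.shape(arr) != expected:
                raise ValueError(f"{group}.{comp} ({name}): shape {np.shape(arr)} != {expected}")
            names.append(name)
    if len(names) != len(set(names)):
        raise ValueError("field names must be unique across vectors and tracers")


# Usage: register a surface current and a 10-m wind on the same C-grid.
nt, nlat, nlon = 2, 3, 4
vectors = {
    "current": {"u": ("u", np.zeros((nt, nlat, nlon - 1))), "v": ("v", np.zeros((nt, nlat - 1, nlon)))},
    "wind": {"u": ("u10", np.zeros((nt, nlat, nlon - 1))), "v": ("v10", np.zeros((nt, nlat - 1, nlon)))},
}
check_vectors(vectors, nt, nlat, nlon)
faces = face_coords([0.0, 1.0, 2.0], [10.0, 10.5, 11.0, 11.5])
```

Both vector groups share the single set of face coordinates, which is why the overrides (u_lat, u_lon, v_lat, v_lon) are documented as shared by every registered vector.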
- static from_xarray_cgrid(ds, *, vectors, coordinates, tracers=None, staggered_coordinates=None, dtype=jnp.float32, lon_period=None, masks=None)
Load a C-grid Dataset from an xarray Dataset.
Centre coordinates (used for time and tracer fields) come from coordinates; staggered U/V coordinates are auto-derived from the centre grid as half-cell shifts unless overridden via staggered_coordinates. Multiple vector fields living on the same C-grid (e.g. surface current and 10-m wind) are declared as separate entries in vectors.
- Parameters
ds (xr.Dataset) – Source xarray Dataset.
vectors (dict[str, dict[Literal['u', 'v'], tuple[str, str]]]) – Mapping {group_name: {"u": (field_name, xarray_var_name), "v": (field_name, xarray_var_name)}}. The outer key is a free-form label for the vector pair (e.g. "current", "wind"). Each inner (field_name, xarray_var_name) tuple says which xarray variable holds the values and under which name to register the resulting Field in Dataset.fields. U variables have shape (time, nlat, nlon - 1); V variables have shape (time, nlat - 1, nlon).
coordinates (dict[str, str]) – Mapping with keys "time", "lat", "lon" → xarray coord names for the centre grid.
tracers (dict[str, str] | None) – Optional {internal_name: xarray_variable_name} for extra centre-grid fields.
staggered_coordinates (dict[str, str] | None) – Optional override mapping with any subset of the keys "u_lat", "u_lon", "v_lat", "v_lon" → xarray coord names. Unspecified keys are auto-derived. The overrides are shared by every registered vector.
dtype (DTypeLike) – JAX dtype for all arrays (default float32).
lon_period (float | None) – Forwarded to from_arrays_cgrid().
masks (dict[str, Array] | None) – Forwarded to from_arrays_cgrid(). Keys must match the field names declared in vectors and tracers.
- Returns
Dataset with C-grid stagger and Grid metadata.
- Return type
Dataset
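The "any subset of keys, rest auto-derived" rule for staggered_coordinates amounts to a dictionary merge. In the sketch below, a plain dict `coord_arrays` stands in for the xarray Dataset (name to 1-D array), and the coordinate name "glamu" is purely hypothetical; the function is an illustration, not pastax's loader.

```python
import numpy as np

FACE_KEYS = ("u_lat", "u_lon", "v_lat", "v_lon")


def resolve_staggered(coord_arrays, center_lat, center_lon, staggered_coordinates=None):
    """Face coordinates: named overrides win, missing keys use half-cell shifts.

    coord_arrays stands in for the xarray Dataset: {coord_name: 1-D array}.
    """
    names = dict(staggered_coordinates or {})
    unknown = set(names) - set(FACE_KEYS)
    if unknown:
        raise KeyError(f"unexpected keys: {sorted(unknown)}")
    lat = np.asarray(center_lat, dtype=float)
    lon = np.asarray(center_lon, dtype=float)
    derived = {
        "u_lat": lat,
        "u_lon": lon[:-1] + 0.5 * (lon[1] - lon[0]),
        "v_lat": lat[:-1] + 0.5 * (lat[1] - lat[0]),
        "v_lon": lon,
    }
    # Look up an override by its coord name when given, else fall back to the default.
    return {
        key: np.asarray(coord_arrays[names[key]], dtype=float) if key in names else derived[key]
        for key in FACE_KEYS
    }
```

Overriding only "u_lon" (for example, when the file stores exact U-point longitudes) leaves the other three face coordinates at their derived values.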