Skip to content

Commit 7670a5a

Browse files
Removing default mesh from XGrid.from_dataset()
1 parent 418b105 commit 7670a5a

8 files changed

Lines changed: 38 additions & 37 deletions

File tree

src/parcels/_core/xgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(self, grid: xgcm.Grid, mesh):
124124
self._ds = ds
125125

126126
@classmethod
127-
def from_dataset(cls, ds: xr.Dataset, mesh="flat", xgcm_kwargs=None):
127+
def from_dataset(cls, ds: xr.Dataset, mesh, xgcm_kwargs=None):
128128
"""WARNING: unstable API, subject to change in future versions.""" # TODO v4: make private or remove warning on v4 release
129129
if xgcm_kwargs is None:
130130
xgcm_kwargs = {}

tests/test_advection.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_advection_zonal_periodic():
8484
halo.XG.values = ds.XG.values[1] + 2
8585
ds = xr.concat([ds, halo], dim="XG")
8686

87-
grid = XGrid.from_dataset(ds)
87+
grid = XGrid.from_dataset(ds, mesh="flat")
8888
U = Field("U", ds["U"], grid, interp_method=XLinear)
8989
V = Field("V", ds["V"], grid, interp_method=XLinear)
9090
UV = VectorField("UV", U, V)
@@ -103,7 +103,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
103103
"""Flat 2D zonal flow that increases linearly with z from 0 m/s to 1 m/s."""
104104
ds = simple_UV_dataset(mesh="flat")
105105
ds["U"].data[:] = 1.0
106-
grid = XGrid.from_dataset(ds)
106+
grid = XGrid.from_dataset(ds, mesh="flat")
107107
U = Field("U", ds["U"], grid, interp_method=XLinear)
108108
U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface
109109
V = Field("V", ds["V"], grid, interp_method=XLinear)
@@ -121,7 +121,7 @@ def test_horizontal_advection_in_3D_flow(npart=10):
121121
@pytest.mark.parametrize("wErrorThroughSurface", [True, False])
122122
def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
123123
ds = simple_UV_dataset(mesh="flat")
124-
grid = XGrid.from_dataset(ds)
124+
grid = XGrid.from_dataset(ds, mesh="flat")
125125
U = Field("U", ds["U"], grid, interp_method=XLinear)
126126
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
127127
V = Field("V", ds["V"], grid, interp_method=XLinear)
@@ -202,7 +202,7 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read
202202
if w:
203203
ds["W"] = (["time", "depth", "YG", "XG"], W)
204204

205-
grid = XGrid.from_dataset(ds)
205+
grid = XGrid.from_dataset(ds, mesh="flat")
206206
U = Field("U", ds["U"], grid, interp_method=XLinear)
207207
V = Field("V", ds["V"], grid, interp_method=XLinear)
208208
fields = [U, V, VectorField("UV", U, V)]
@@ -263,7 +263,7 @@ def test_radialrotation(npart=10):
263263
)
264264
def test_moving_eddy(kernel, rtol):
265265
ds = moving_eddy_dataset()
266-
grid = XGrid.from_dataset(ds)
266+
grid = XGrid.from_dataset(ds, mesh="flat")
267267
U = Field("U", ds["U"], grid, interp_method=XLinear)
268268
V = Field("V", ds["V"], grid, interp_method=XLinear)
269269
if kernel in [AdvectionRK2_3D, AdvectionRK4_3D]:
@@ -315,7 +315,7 @@ def truth_moving(x_0, y_0, t):
315315
)
316316
def test_decaying_moving_eddy(kernel, rtol):
317317
ds = decaying_moving_eddy_dataset()
318-
grid = XGrid.from_dataset(ds)
318+
grid = XGrid.from_dataset(ds, mesh="flat")
319319
U = Field("U", ds["U"], grid, interp_method=XLinear)
320320
V = Field("V", ds["V"], grid, interp_method=XLinear)
321321
UV = VectorField("UV", U, V)
@@ -363,7 +363,7 @@ def truth_moving(x_0, y_0, t):
363363
def test_stommelgyre_fieldset(kernel, rtol, grid_type):
364364
npart = 2
365365
ds = stommel_gyre_dataset(grid_type=grid_type)
366-
grid = XGrid.from_dataset(ds)
366+
grid = XGrid.from_dataset(ds, mesh="flat")
367367
vector_interp_method = None if grid_type == "A" else CGrid_Velocity
368368
U = Field("U", ds["U"], grid, interp_method=XLinear)
369369
V = Field("V", ds["V"], grid, interp_method=XLinear)
@@ -404,7 +404,7 @@ def UpdateP(particles, fieldset): # pragma: no cover
404404
def test_peninsula_fieldset(kernel, rtol, grid_type):
405405
npart = 2
406406
ds = peninsula_dataset(grid_type=grid_type)
407-
grid = XGrid.from_dataset(ds)
407+
grid = XGrid.from_dataset(ds, mesh="flat")
408408
U = Field("U", ds["U"], grid, interp_method=XLinear)
409409
V = Field("V", ds["V"], grid, interp_method=XLinear)
410410
P = Field("P", ds["P"], grid, interp_method=XLinear)

tests/test_field.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def test_field_init_param_types():
1616
data = datasets_structured["ds_2d_left"]
17-
grid = XGrid.from_dataset(data)
17+
grid = XGrid.from_dataset(data, mesh="flat")
1818

1919
with pytest.raises(TypeError, match="Expected a string for variable name, got int instead."):
2020
Field(name=123, data=data["data_g"], grid=grid, interp_method=XLinear)
@@ -47,7 +47,7 @@ def test_field_init_param_types():
4747
[
4848
pytest.param(
4949
ux.UxDataArray(),
50-
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
50+
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
5151
id="uxdata-grid",
5252
),
5353
pytest.param(
@@ -76,7 +76,7 @@ def test_field_incompatible_combination(data, grid):
7676
[
7777
pytest.param(
7878
datasets_structured["ds_2d_left"]["data_g"],
79-
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
79+
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
8080
id="ds_2d_left",
8181
), # TODO: Perhaps this test should be expanded to cover more datasets?
8282
],
@@ -107,7 +107,7 @@ def test_field_init_fail_on_float_time_dim():
107107
)
108108

109109
data = ds["data_g"]
110-
grid = XGrid.from_dataset(ds)
110+
grid = XGrid.from_dataset(ds, mesh="flat")
111111
with pytest.raises(
112112
ValueError,
113113
match="Error getting time interval.*. Are you sure that the time dimension on the xarray dataset is stored as timedelta, datetime or cftime datetime objects\?",
@@ -125,7 +125,7 @@ def test_field_init_fail_on_float_time_dim():
125125
[
126126
pytest.param(
127127
datasets_structured["ds_2d_left"]["data_g"],
128-
XGrid.from_dataset(datasets_structured["ds_2d_left"]),
128+
XGrid.from_dataset(datasets_structured["ds_2d_left"], mesh="flat"),
129129
id="ds_2d_left",
130130
),
131131
],
@@ -144,7 +144,7 @@ def test_vectorfield_init_different_time_intervals():
144144

145145
def test_field_invalid_interpolator():
146146
ds = datasets_structured["ds_2d_left"]
147-
grid = XGrid.from_dataset(ds)
147+
grid = XGrid.from_dataset(ds, mesh="flat")
148148

149149
def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
150150
return 0.0
@@ -161,7 +161,7 @@ def invalid_interpolator_wrong_signature(particle_positions, grid_positions, inv
161161

162162
def test_vectorfield_invalid_interpolator():
163163
ds = datasets_structured["ds_2d_left"]
164-
grid = XGrid.from_dataset(ds)
164+
grid = XGrid.from_dataset(ds, mesh="flat")
165165

166166
def invalid_interpolator_wrong_signature(particle_positions, grid_positions, invalid):
167167
return 0.0

tests/test_index_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
@pytest.fixture
1414
def field_cone():
1515
ds = datasets["2d_left_unrolled_cone"]
16-
grid = XGrid.from_dataset(ds)
16+
grid = XGrid.from_dataset(ds, mesh="flat")
1717
field = Field(
1818
name="test_field",
1919
data=ds["data_g"],

tests/test_interpolation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def field():
5252
"y": (["y"], [0.5, 1.5, 2.5, 3.5], {"axis": "Y"}),
5353
},
5454
)
55-
return Field("U", ds["U"], XGrid.from_dataset(ds), interp_method=XLinear)
55+
return Field("U", ds["U"], XGrid.from_dataset(ds, mesh="flat"), interp_method=XLinear)
5656

5757

5858
@pytest.mark.parametrize(

tests/test_particlefile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
def fieldset() -> FieldSet: # TODO v4: Move into a `conftest.py` file and remove duplicates
3333
"""Fixture to create a FieldSet object for testing."""
3434
ds = datasets["ds_2d_left"]
35-
grid = XGrid.from_dataset(ds)
35+
grid = XGrid.from_dataset(ds, mesh="flat")
3636
U = Field("U", ds["U_A_grid"], grid, XLinear)
3737
V = Field("V", ds["V_A_grid"], grid, XLinear)
3838
UV = VectorField("UV", U, V)
@@ -73,7 +73,7 @@ def test_pfile_array_write_zarr_memorystore(fieldset):
7373
def test_write_fieldset_without_time(tmp_zarrfile):
7474
ds = peninsula_dataset() # DataSet without time
7575
assert "time" not in ds.dims
76-
grid = XGrid.from_dataset(ds)
76+
grid = XGrid.from_dataset(ds, mesh="flat")
7777
fieldset = FieldSet([Field("U", ds["U"], grid, XLinear)])
7878

7979
pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0)

tests/test_spatialhash.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
def test_spatialhash_init():
88
ds = datasets["2d_left_rotated"]
9-
grid = XGrid.from_dataset(ds)
9+
grid = XGrid.from_dataset(ds, mesh="flat")
1010
spatialhash = grid.get_spatial_hash()
1111
assert spatialhash is not None
1212

1313

1414
def test_invalid_positions():
1515
ds = datasets["2d_left_rotated"]
16-
grid = XGrid.from_dataset(ds)
16+
grid = XGrid.from_dataset(ds, mesh="flat")
1717

1818
j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
1919
assert np.all(j == -3)
@@ -22,7 +22,7 @@ def test_invalid_positions():
2222

2323
def test_mixed_positions():
2424
ds = datasets["2d_left_rotated"]
25-
grid = XGrid.from_dataset(ds)
25+
grid = XGrid.from_dataset(ds, mesh="flat")
2626
lat = grid.lat.mean()
2727
lon = grid.lon.mean()
2828
y = [lat, np.nan]

tests/test_xgrid.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,19 @@ def test_grid_init_param_types(ds):
5555

5656
@pytest.mark.parametrize("ds, attr, expected", test_cases)
5757
def test_xgrid_properties_ground_truth(ds, attr, expected):
58-
grid = XGrid.from_dataset(ds)
58+
grid = XGrid.from_dataset(ds, mesh="flat")
5959
actual = getattr(grid, attr)
6060
assert_equal(actual, expected)
6161

6262

6363
@pytest.mark.parametrize("ds", [pytest.param(ds, id=key) for key, ds in datasets.items()])
6464
def test_xgrid_from_dataset_on_generic_datasets(ds):
65-
XGrid.from_dataset(ds)
65+
XGrid.from_dataset(ds, mesh="flat")
6666

6767

68+
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
6869
def test_xgrid_axes(ds):
69-
grid = XGrid.from_dataset(ds)
70+
grid = XGrid.from_dataset(ds, mesh="flat")
7071
assert grid.axes == ["Z", "Y", "X"]
7172

7273

@@ -80,7 +81,7 @@ def test_uxgrid_mesh(ds, mesh):
8081
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
8182
def test_transpose_xfield_data_to_tzyx(ds):
8283
da = ds["data_g"]
83-
grid = XGrid.from_dataset(ds)
84+
grid = XGrid.from_dataset(ds, mesh="flat")
8485

8586
all_combinations = (itertools.combinations(da.dims, n) for n in range(len(da.dims)))
8687
all_combinations = itertools.chain(*all_combinations)
@@ -93,7 +94,7 @@ def test_transpose_xfield_data_to_tzyx(ds):
9394

9495
@pytest.mark.parametrize("ds", [datasets["ds_2d_left"]])
9596
def test_xgrid_get_axis_dim(ds):
96-
grid = XGrid.from_dataset(ds)
97+
grid = XGrid.from_dataset(ds, mesh="flat")
9798
assert grid.get_axis_dim("Z") == Z - 1
9899
assert grid.get_axis_dim("Y") == Y - 1
99100
assert grid.get_axis_dim("X") == X - 1
@@ -113,15 +114,15 @@ def test_invalid_lon_lat():
113114
ValueError,
114115
match=".*is defined on the center of the grid, but must be defined on the F points\.",
115116
):
116-
XGrid.from_dataset(ds)
117+
XGrid.from_dataset(ds, mesh="flat")
117118

118119
ds = datasets["ds_2d_left"].copy()
119120
ds["lon"], _ = xr.broadcast(ds["YG"], ds["XG"])
120121
with pytest.raises(
121122
ValueError,
122123
match=".*have different dimensionalities\.",
123124
):
124-
XGrid.from_dataset(ds)
125+
XGrid.from_dataset(ds, mesh="flat")
125126

126127
ds = datasets["ds_2d_left"].copy()
127128
ds["lon"], ds["lat"] = xr.broadcast(ds["YG"], ds["XG"])
@@ -131,20 +132,20 @@ def test_invalid_lon_lat():
131132
ValueError,
132133
match=".*must be defined on the X and Y axes and transposed to have dimensions in order of Y, X\.",
133134
):
134-
XGrid.from_dataset(ds)
135+
XGrid.from_dataset(ds, mesh="flat")
135136

136137

137138
def test_invalid_depth():
138139
ds = datasets["ds_2d_left"].copy()
139140
ds = ds.reindex({"ZG": ds.ZG[::-1]})
140141

141142
with pytest.raises(ValueError, match="Depth DataArray .* must be strictly increasing*"):
142-
XGrid.from_dataset(ds)
143+
XGrid.from_dataset(ds, mesh="flat")
143144

144145

145146
def test_dim_without_axis():
146147
ds = xr.Dataset({"z1d": (["depth"], [0])}, coords={"depth": [0]})
147-
grid = XGrid.from_dataset(ds)
148+
grid = XGrid.from_dataset(ds, mesh="flat")
148149
with pytest.raises(ValueError, match='Dimension "depth" has no axis attribute*'):
149150
Field("z1d", ds["z1d"], grid, XLinear)
150151

@@ -155,7 +156,7 @@ def test_vertical1D_field():
155156
{"z1d": (["depth"], np.linspace(0, 10, nz))},
156157
coords={"depth": (["depth"], np.linspace(0, 1, nz), {"axis": "Z"})},
157158
)
158-
grid = XGrid.from_dataset(ds)
159+
grid = XGrid.from_dataset(ds, mesh="flat")
159160
field = Field("z1d", ds["z1d"], grid, XLinear)
160161

161162
assert field.eval(np.timedelta64(0, "s"), 0.45, 0, 0) == 4.5
@@ -167,7 +168,7 @@ def test_time1D_field():
167168
{"t1d": (["time"], np.arange(0, len(timerange)))},
168169
coords={"time": (["time"], timerange, {"axis": "T"})},
169170
)
170-
grid = XGrid.from_dataset(ds)
171+
grid = XGrid.from_dataset(ds, mesh="flat")
171172
field = Field("t1d", ds["t1d"], grid, XLinear)
172173

173174
assert field.eval(np.datetime64("2000-01-10T12:00:00"), -20, 5, 6) == 9.5
@@ -181,7 +182,7 @@ def test_time1D_field():
181182
],
182183
) # for key, ds in datasets.items()])
183184
def test_xgrid_search_cpoints(ds):
184-
grid = XGrid.from_dataset(ds)
185+
grid = XGrid.from_dataset(ds, mesh="flat")
185186
lat_array, lon_array = get_2d_fpoint_mesh(grid)
186187
lat_array, lon_array = corner_to_cell_center_points(lat_array, lon_array)
187188

@@ -299,7 +300,7 @@ def test_search_1d_array_some_out_of_bounds(array, x, expected_xi):
299300
)
300301
def test_xgrid_localize_zero_position(ds, da_name, expected):
301302
"""Test localize function using left and right datasets."""
302-
grid = XGrid.from_dataset(ds)
303+
grid = XGrid.from_dataset(ds, mesh="flat")
303304
da = ds[da_name]
304305
position = grid.search(0, 0, 0)
305306

0 commit comments

Comments
 (0)