Skip to content

Commit bf2e41f

Browse files
committed
USe headers: HeaderSpec
1 parent 67a03ec commit bf2e41f

2 files changed

Lines changed: 63 additions & 46 deletions

File tree

src/mdio/creators/mdio.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
from typing import TYPE_CHECKING
66

7-
from segy.standards import get_segy_standard
8-
97
from mdio.api.io import _normalize_path
108
from mdio.api.io import to_mdio
119
from mdio.builder.template_registry import TemplateRegistry
@@ -17,6 +15,7 @@
1715
if TYPE_CHECKING:
1816
from pathlib import Path
1917

18+
from segy.schema import HeaderSpec
2019
from upath import UPath
2120
from xarray import Dataset as xr_Dataset
2221

@@ -28,7 +27,7 @@ def create_empty( # noqa PLR0913
2827
mdio_template_name: str,
2928
dimensions: list[Dimension],
3029
output_path: UPath | Path | str,
31-
create_headers: bool = False,
30+
headers: HeaderSpec | None = None,
3231
overwrite: bool = False,
3332
) -> None:
3433
"""A function that creates an empty MDIO v1 file with known dimensions.
@@ -37,7 +36,7 @@ def create_empty( # noqa PLR0913
3736
mdio_template_name: The MDIO template to use to define the dataset structure.
3837
dimensions: The dimensions of the MDIO file.
3938
output_path: The universal path for the output MDIO v1 file.
40-
create_headers: Whether to create a full set of SEG-Y v1.0 trace headers. Defaults to False.
39+
headers: SEG-Y v1.0 trace headers. Defaults to None.
4140
overwrite: Whether to overwrite the output file if it already exists. Defaults to False.
4241
4342
Raises:
@@ -49,7 +48,7 @@ def create_empty( # noqa PLR0913
4948
err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
5049
raise FileExistsError(err)
5150

52-
header_dtype = to_structured_type(get_segy_standard(1.0).trace.header.dtype) if create_headers else None
51+
header_dtype = to_structured_type(headers.dtype) if headers else None
5352
grid = Grid(dims=dimensions)
5453
mdio_template = TemplateRegistry().get(mdio_template_name)
5554
mdio_ds: Dataset = mdio_template.build_dataset(

tests/integration/test_create_empty.py

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,51 @@
33
from __future__ import annotations
44

55
import math
6-
from turtle import speed
76
from typing import TYPE_CHECKING
87

98
import numpy as np
109
import pytest
11-
from segy.standards import get_segy_standard
10+
from segy.schema import HeaderField
11+
from segy.schema import HeaderSpec
1212

13-
from mdio.builder.schemas.v1.units import LengthUnitEnum, LengthUnitModel, SpeedUnitEnum, SpeedUnitModel, TimeUnitEnum, TimeUnitModel
13+
from mdio.builder.schemas.v1.units import LengthUnitEnum
14+
from mdio.builder.schemas.v1.units import LengthUnitModel
15+
from mdio.builder.schemas.v1.units import SpeedUnitEnum
16+
from mdio.builder.schemas.v1.units import SpeedUnitModel
17+
from mdio.builder.schemas.v1.units import TimeUnitEnum
18+
from mdio.builder.schemas.v1.units import TimeUnitModel
1419

1520
if TYPE_CHECKING:
1621
from pathlib import Path
1722

1823
from xarray import Dataset as xr_Dataset
1924

20-
from mdio.builder.schemas.v1.stats import CenteredBinHistogram, SummaryStatistics
21-
from tests.integration.test_segy_roundtrip_teapot import text_header_teapot_dome
2225
from tests.integration.testing_helpers import get_values
2326
from tests.integration.testing_helpers import validate_variable
2427

2528
from mdio import __version__
26-
from mdio.api.io import open_mdio, to_mdio
29+
from mdio.api.io import open_mdio
30+
from mdio.api.io import to_mdio
31+
from mdio.builder.schemas.v1.stats import CenteredBinHistogram
32+
from mdio.builder.schemas.v1.stats import SummaryStatistics
2733
from mdio.core import Dimension
2834
from mdio.creators.mdio import create_empty
2935

3036

3137
class TestCreateEmptyPostStack3DTimeMdio:
3238
"""Tests for create_empty_mdio function."""
3339

40+
@classmethod
41+
def _get_header_spec(cls) -> HeaderSpec:
42+
"""Get the header spec for the MDIO dataset."""
43+
trace_header_fields = [
44+
HeaderField(name="inline", byte=17, format="int32"),
45+
HeaderField(name="crossline", byte=13, format="int32"),
46+
HeaderField(name="cdp_x", byte=181, format="int32"),
47+
HeaderField(name="cdp_y", byte=185, format="int32"),
48+
]
49+
return HeaderSpec(fields=trace_header_fields)
50+
3451
@classmethod
3552
def _validate_empty_mdio_dataset(cls, ds: xr_Dataset, has_headers: bool) -> None:
3653
"""Validate an empty MDIO dataset structure and content."""
@@ -49,7 +66,7 @@ def _validate_empty_mdio_dataset(cls, ds: xr_Dataset, has_headers: bool) -> None
4966
if has_headers:
5067
# Validate the headers (should be empty for empty dataset)
5168
# Infer the dtype from segy_spec and ignore endianness
52-
header_dtype = get_segy_standard(1.0).trace.header.dtype.newbyteorder("native")
69+
header_dtype = cls._get_header_spec().dtype.newbyteorder("native")
5370
validate_variable(ds, "headers", (200, 300), ("inline", "crossline"), header_dtype, None, None)
5471
else:
5572
assert "headers" not in ds.variables
@@ -72,12 +89,12 @@ def _create_empty_mdio(cls, create_headers: bool, output_path: Path, overwrite:
7289
Dimension(name="time", coords=range(0, 3000, 4)), # 0-3 seconds 4ms sample rate
7390
]
7491

75-
# Call create_empty_mdio
92+
headers = cls._get_header_spec() if create_headers else None
7693
create_empty(
7794
mdio_template_name="PostStack3DTime",
7895
dimensions=dims,
7996
output_path=output_path,
80-
create_headers=create_headers,
97+
headers=headers,
8198
overwrite=overwrite,
8299
)
83100

@@ -173,33 +190,31 @@ def test_overwrite_behavior(self, empty_mdio_dir: Path) -> None:
173190
assert not garbage_file.exists(), "Garbage file should have been overwritten"
174191
assert not garbage_dir.exists(), "Garbage directory should have been overwritten"
175192

176-
177193
def test_populate_empty_dataset(self, mdio_with_headers: Path) -> None:
178194
"""Test showing how to populate empty dataset."""
179-
180195
# Open an empty PostStack3DTime dataset with SEG-Y 1.0 headers
181196
# NOTES:
182-
# When this empty dataset was created from the 'PostStack3DTime' template and dimensions,
197+
# When this empty dataset was created from the 'PostStack3DTime' template and dimensions,
183198
# * 'inline', 'crossline', and 'time' dimension coordinate variables were created and pre-populated
184199
# * 'cdp_x', 'cdp_y' non-dimensional coordinate variables were created
185200
# * 'amplitude' variable was created (the name of this variable is specified in the template)
186201
# HACK: in this example, we will use this variable to store the velocity data
187-
# * 'trace_mask' variable was created and pre-populated with 'False' fill values
202+
# * 'trace_mask' variable was created and pre-populated with 'False' fill values
188203
# (all traces are marked as dead)
189-
# * 'headers' segy trace headers variable was created (if the dataset was created with create_headers=true)
204+
# * 'headers' segy trace headers variable was created (if the dataset was created with headers not None)
190205
# * dataset attribute called 'attributes' was created
191-
ds = open_mdio(mdio_with_headers)
206+
ds = open_mdio(mdio_with_headers)
192207

193-
# 1.A) Populate dataset's velocity
208+
# 1) Populate dataset's velocity
194209
var_name = ds.attrs["attributes"]["defaultVariableName"]
195210
velocity = ds[var_name]
196-
velocity[:5,:,:] = 1
197-
velocity[5:10,:,:] = 2
198-
velocity[50:100,:,:] = 3
199-
velocity[150:175,:,:] = -1
211+
velocity[:5, :, :] = 1
212+
velocity[5:10, :, :] = 2
213+
velocity[50:100, :, :] = 3
214+
velocity[150:175, :, :] = -1
200215

201-
# 1.B) Populate dataset's velocity statistics (optional)
202-
nonzero_samples = np.ma.masked_invalid(velocity, copy=False)
216+
# 2) Populate dataset's velocity statistics (optional)
217+
nonzero_samples = np.ma.masked_invalid(velocity, copy=False)
203218
stats = SummaryStatistics(
204219
count=nonzero_samples.count(),
205220
min=nonzero_samples.min(),
@@ -210,35 +225,38 @@ def test_populate_empty_dataset(self, mdio_with_headers: Path) -> None:
210225
)
211226
velocity.attrs["statsV1"] = stats.model_dump_json()
212227

213-
# 1.C) Set coordinate and data variable units (optional)
214-
ds.time["unitsV1"] = TimeUnitModel(time=TimeUnitEnum.MILLISECOND).model_dump_json()
215-
216-
ds.cdp_x.attrs["unitsV1"] = LengthUnitModel(length=LengthUnitEnum.FOOT).model_dump_json()
217-
ds.cdp_x.attrs["unitsV1"] = LengthUnitModel(length=LengthUnitEnum.FOOT).model_dump_json()
218-
219-
velocity.attrs["unitsV1"] = SpeedUnitModel(speed=SpeedUnitEnum.FEET_PER_SECOND).model_dump_json()
220-
221228
# 3) Populate the non-dimensional coordinate variables 'cdp_x' and 'cdp_y' (optional)
222-
origin = [270000, 3290000] # survey x, y origin
223-
inline_azimuth_rad = 0.523599 # survey orientation, in radians, from the north to the east (30 degrees)
224-
spacing = [50, 50] # survey inline, crossline spacing
225-
inline_grid, xline_grid = np.meshgrid(ds.inline.values, ds.crossline.values, indexing='ij')
229+
origin = [270000, 3290000] # survey x, y origin
230+
inline_azimuth_rad = 0.523599 # survey orientation, in radians, from the north to the east (30 degrees)
231+
spacing = [50, 50] # survey inline, crossline spacing
232+
inline_grid, xline_grid = np.meshgrid(ds.inline.values, ds.crossline.values, indexing="ij")
226233
sin_azimuth = math.sin(inline_azimuth_rad)
227234
cos_azimuth = math.cos(inline_azimuth_rad)
228235
ds.cdp_x[:] = origin[0] + inline_grid * spacing[0] * sin_azimuth + xline_grid * spacing[1] * cos_azimuth
229236
ds.cdp_y[:] = origin[1] + inline_grid * spacing[0] * cos_azimuth - xline_grid * spacing[1] * sin_azimuth
230237

231238
# 4) Populate dataset's trace mask (optional)
232-
ds.trace_mask[:] = ~np.isnan(velocity[:,:,0])
239+
ds.trace_mask[:] = ~np.isnan(velocity[:, :, 0])
233240

234-
# 5) Populate dataset's segy trace headers, if those were created (optional)
235-
if "headers" in ds.variables:
236-
ds.headers["cdp_x"][:] = ds.cdp_x
237-
ds.headers["cdp_y"][:] = ds.cdp_y
241+
# 5) Set coordinate and data variable units (optional)
242+
ds.time["unitsV1"] = TimeUnitModel(time=TimeUnitEnum.MILLISECOND).model_dump_json()
238243

239-
# 5) Create dataset's custom attributes (optional)
244+
ds.cdp_x.attrs["unitsV1"] = LengthUnitModel(length=LengthUnitEnum.FOOT).model_dump_json()
245+
ds.cdp_x.attrs["unitsV1"] = LengthUnitModel(length=LengthUnitEnum.FOOT).model_dump_json()
246+
247+
velocity.attrs["unitsV1"] = SpeedUnitModel(speed=SpeedUnitEnum.FEET_PER_SECOND).model_dump_json()
248+
249+
# 6) Populate dataset's segy trace headers, if those were created (optional)
250+
if "headers" in ds.variables:
251+
# numpy broadcasting (200, 1) array to (200, 300) array
252+
ds["headers"].values["inline"] = ds.inline.values[:, np.newaxis]
253+
# numpy broadcasting (1, 300) array to (200, 300) array
254+
ds["headers"].values["crossline"] = ds.crossline.values[np.newaxis, :]
255+
ds["headers"]["cdp_x"][:] = ds.cdp_x
256+
ds["headers"]["cdp_y"][:] = ds.cdp_y
257+
258+
# 7) Create dataset's custom attributes (optional)
240259
ds.attrs["attributes"]["createdBy"] = "John Doe"
241260

242261
output_path = mdio_with_headers.parent / "populated_empty.mdio"
243262
to_mdio(ds, output_path=output_path, mode="w", compute=True)
244-

0 commit comments

Comments
 (0)