Skip to content

Commit 3e401e2

Browse files
committed
refactor tests and apply DRY principle.
1 parent db9a4f3 commit 3e401e2

1 file changed

Lines changed: 53 additions & 52 deletions

File tree

tests/unit/test_segy_grid_overrides.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from __future__ import annotations
55

6+
from typing import Any
7+
68
import numpy.typing as npt
79
import pytest
810
from numpy import arange
@@ -23,6 +25,27 @@
2325
RECEIVERS = arange(1, 6, dtype="int32")
2426

2527

28+
def run_override(
29+
grid_overrides: dict[str, Any],
30+
index_names: tuple[str, ...],
31+
headers: dict[str, npt.NDArray],
32+
chunksize: tuple[int, ...] | None = None,
33+
) -> tuple[dict[str, Any], tuple[str], tuple[int]]:
34+
"""Initialize and run overrider."""
35+
overrider = GridOverrider()
36+
return overrider.run(headers, index_names, grid_overrides, chunksize)
37+
38+
39+
def get_dims(headers: dict[str, npt.NDArray]) -> list[Dimension]:
40+
"""Get list of Dimensions from headers."""
41+
dims = []
42+
for index_name, index_coords in headers.items():
43+
dim_unique = unique(index_coords)
44+
dims.append(Dimension(coords=dim_unique, name=index_name))
45+
46+
return dims
47+
48+
2649
@pytest.fixture
2750
def mock_streamer_headers() -> dict[str, npt.NDArray]:
2851
"""Generate dictionary of mocked streamer index headers."""
@@ -48,55 +71,49 @@ class TestAutoGridOverrides:
4871

4972
def test_duplicates(self, mock_streamer_headers: npt.NDArray) -> None:
5073
"""Test the HasDuplicates Grid Override command."""
74+
index_names = ("shot", "cable")
5175
grid_overrides = {"HasDuplicates": True}
5276

53-
# mock_streamer_headers["trace"] = mock_streamer_headers["channel"]
5477
# Remove channel header
5578
del mock_streamer_headers["channel"]
56-
index_names = ("shot", "cable")
5779
chunksize = (4, 4, 8)
5880

59-
overrider = GridOverrider()
60-
results, new_names, new_chunks = overrider.run(
61-
mock_streamer_headers, index_names, grid_overrides, chunksize
81+
new_headers, new_names, new_chunks = run_override(
82+
grid_overrides,
83+
index_names,
84+
mock_streamer_headers,
85+
chunksize,
6286
)
6387

6488
assert new_names == ("shot", "cable", "trace")
6589
assert new_chunks == (4, 4, 1, 8)
6690

67-
dims = []
68-
for index_name, index_coords in results.items():
69-
dim_unique = unique(index_coords)
70-
dims.append(Dimension(coords=dim_unique, name=index_name))
91+
dims = get_dims(new_headers)
7192

7293
assert_array_equal(dims[0].coords, SHOTS)
7394
assert_array_equal(dims[1].coords, CABLES)
7495
assert_array_equal(dims[2].coords, RECEIVERS)
7596

7697
def test_non_binned(self, mock_streamer_headers: npt.NDArray) -> None:
7798
"""Test the NonBinned Grid Override command."""
99+
index_names = ("shot", "cable")
78100
grid_overrides = {"NonBinned": True, "chunksize": 4}
79101

80102
# Remove channel header
81103
del mock_streamer_headers["channel"]
82-
index_names = (
83-
"shot",
84-
"cable",
85-
)
86104
chunksize = (4, 4, 8)
87105

88-
overrider = GridOverrider()
89-
results, new_names, new_chunks = overrider.run(
90-
mock_streamer_headers, index_names, grid_overrides, chunksize
106+
new_headers, new_names, new_chunks = run_override(
107+
grid_overrides,
108+
index_names,
109+
mock_streamer_headers,
110+
chunksize,
91111
)
92112

93113
assert new_names == ("shot", "cable", "trace")
94114
assert new_chunks == (4, 4, 4, 8)
95115

96-
dims = []
97-
for index_name, index_coords in results.items():
98-
dim_unique = unique(index_coords)
99-
dims.append(Dimension(coords=dim_unique, name=index_name))
116+
dims = get_dims(new_headers)
100117

101118
assert_array_equal(dims[0].coords, SHOTS)
102119
assert_array_equal(dims[1].coords, CABLES)
@@ -108,45 +125,35 @@ class TestStreamerGridOverrides:
108125

109126
def test_channel_wrap(self, mock_streamer_headers: npt.NDArray) -> None:
110127
"""Test the ChannelWrap command."""
111-
grid_overrides = {"ChannelWrap": True, "ChannelsPerCable": len(RECEIVERS)}
112128
index_names = ("shot", "cable", "channel")
113-
chunksize = None
114-
overrider = GridOverrider()
115-
results, new_names, new_chunks = overrider.run(
116-
mock_streamer_headers, index_names, grid_overrides, chunksize
129+
grid_overrides = {"ChannelWrap": True, "ChannelsPerCable": len(RECEIVERS)}
130+
131+
new_headers, new_names, new_chunks = run_override(
132+
grid_overrides, index_names, mock_streamer_headers
117133
)
118134

119135
assert new_names == index_names
120136
assert new_chunks is None
121-
dims = []
122-
for index_name, index_coords in results.items():
123-
dim_unique = unique(index_coords)
124-
dims.append(Dimension(coords=dim_unique, name=index_name))
125-
126-
assert_array_equal(dims[0], SHOTS)
127-
assert_array_equal(dims[1], CABLES)
128-
assert_array_equal(dims[2], RECEIVERS)
137+
138+
dims = get_dims(new_headers)
139+
129140
assert_array_equal(dims[0].coords, SHOTS)
130141
assert_array_equal(dims[1].coords, CABLES)
131142
assert_array_equal(dims[2].coords, RECEIVERS)
132143

133144
def test_calculate_cable(self, mock_streamer_headers: npt.NDArray) -> None:
134145
"""Test the CalculateCable command."""
135-
grid_overrides = {"CalculateCable": True, "ChannelsPerCable": len(RECEIVERS)}
136146
index_names = ("shot", "cable", "channel")
137-
chunksize = None
138-
overrider = GridOverrider()
139-
results, new_names, new_chunks = overrider.run(
140-
mock_streamer_headers, index_names, grid_overrides, chunksize
147+
grid_overrides = {"CalculateCable": True, "ChannelsPerCable": len(RECEIVERS)}
148+
149+
new_headers, new_names, new_chunks = run_override(
150+
grid_overrides, index_names, mock_streamer_headers
141151
)
142152

143153
assert new_names == index_names
144154
assert new_chunks is None
145155

146-
dims = []
147-
for index_name, index_coords in results.items():
148-
dim_unique = unique(index_coords)
149-
dims.append(Dimension(coords=dim_unique, name=index_name))
156+
dims = get_dims(new_headers)
150157

151158
# We need channels because unwrap isn't done here
152159
channels = unique(mock_streamer_headers["channel"])
@@ -160,27 +167,21 @@ def test_calculate_cable(self, mock_streamer_headers: npt.NDArray) -> None:
160167

161168
def test_wrap_and_calc_cable(self, mock_streamer_headers: npt.NDArray) -> None:
162169
"""Test the combined ChannelWrap and CalculateCable commands."""
170+
index_names = ("shot", "cable", "channel")
163171
grid_overrides = {
164172
"CalculateCable": True,
165173
"ChannelWrap": True,
166174
"ChannelsPerCable": len(RECEIVERS),
167175
}
168176

169-
index_names = ("shot", "cable", "channel")
170-
chunksize = None
171-
overrider = GridOverrider()
172-
results, new_names, new_chunks = overrider.run(
173-
mock_streamer_headers, index_names, grid_overrides, chunksize
177+
new_headers, new_names, new_chunks = run_override(
178+
grid_overrides, index_names, mock_streamer_headers
174179
)
175180

176181
assert new_names == index_names
177182
assert new_chunks is None
178183

179-
dims = []
180-
for index_name, index_coords in results.items():
181-
dim_unique = unique(index_coords)
182-
dims.append(Dimension(coords=dim_unique, name=index_name))
183-
184+
dims = get_dims(new_headers)
184185
# We reset the cables to start from 1.
185186
cables = arange(1, len(CABLES) + 1, dtype="uint32")
186187

0 commit comments

Comments
 (0)