Skip to content

Commit d7b6a1d

Browse files
committed
use design pattern for grid overrides, simplify logic
1 parent 2fa8acb commit d7b6a1d

5 files changed

Lines changed: 314 additions & 150 deletions

File tree

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[flake8]
22
select = B,B9,C,D,DAR,E,F,N,RST,S,W
33
ignore = E203,E501,RST201,RST203,RST301,W503
4-
max-line-length = 80
4+
max-line-length = 88
55
max-complexity = 10
66
docstring-convention = google
77
rst-roles = class,const,func,meth,mod,ref

src/mdio/segy/exceptions.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,49 @@
77
class InvalidSEGYFileError(MDIOError):
88
"""Raised when there is an IOError from segyio."""
99

10-
pass
10+
11+
class GridOverrideInputError(MDIOError):
12+
"""Raised when grid override parameters are not correct."""
13+
14+
15+
class GridOverrideUnknownError(GridOverrideInputError):
16+
"""Raised when grid override parameter is unknown."""
17+
18+
def __init__(self, command_name):
19+
"""Initialize with custom message."""
20+
self.command_name = command_name
21+
self.message = f"Unknown grid override: {command_name}"
22+
super().__init__(self.message)
23+
24+
25+
class GridOverrideKeysError(GridOverrideInputError):
26+
"""Raised when grid override is not compatible with required keys."""
27+
28+
def __init__(self, command_name, required_keys):
29+
"""Initialize with custom message."""
30+
self.command_name = command_name
31+
self.required_keys = required_keys
32+
self.message = f"{command_name} can only be used with {required_keys} keys."
33+
super().__init__(self.message)
34+
35+
36+
class GridOverrideMissingParameterError(GridOverrideInputError):
37+
"""Raised when grid override parameters are not correct."""
38+
39+
def __init__(self, command_name, missing_parameter):
40+
"""Initialize with custom message."""
41+
self.command_name = command_name
42+
self.missing_parameter = missing_parameter
43+
self.message = f"{command_name} requires {missing_parameter} parameter."
44+
super().__init__(self.message)
45+
46+
47+
class GridOverrideIncompatibleError(GridOverrideInputError):
48+
"""Raised when two grid overrides are incompatible."""
49+
50+
def __init__(self, first_command, second_command):
51+
"""Initialize with custom message."""
52+
self.first_command = first_command
53+
self.second_command = second_command
54+
self.message = f"{first_command} can't be used together with {second_command}."
55+
super().__init__(self.message)

src/mdio/segy/geometry.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""SEG-Y geometry handling functions."""
2+
3+
4+
from __future__ import annotations
5+
6+
import logging
7+
from abc import ABC
8+
from abc import abstractmethod
9+
from enum import Enum
10+
from enum import auto
11+
12+
import numpy as np
13+
import numpy.typing as npt
14+
15+
from mdio.segy.exceptions import GridOverrideIncompatibleError
16+
from mdio.segy.exceptions import GridOverrideKeysError
17+
from mdio.segy.exceptions import GridOverrideMissingParameterError
18+
from mdio.segy.exceptions import GridOverrideUnknownError
19+
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class StreamerShotGeometryType(Enum):
25+
r"""Shot geometry template types for streamer acquisition.
26+
27+
Configuration A:
28+
Cable 1 -> 1------------------20
29+
Cable 2 -> 1-----------------20
30+
. 1-----------------20
31+
. ⛴ ☆ 1-----------------20
32+
. 1-----------------20
33+
Cable 6 -> 1-----------------20
34+
Cable 7 -> 1-----------------20
35+
36+
37+
Configuration B:
38+
Cable 1 -> 1------------------20
39+
Cable 2 -> 21-----------------40
40+
. 41-----------------60
41+
. ⛴ ☆ 61-----------------80
42+
. 81----------------100
43+
Cable 6 -> 101---------------120
44+
Cable 7 -> 121---------------140
45+
46+
Configuration C:
47+
Cable ? -> / 1------------------20
48+
Cable ? -> / 21-----------------40
49+
. / 41-----------------60
50+
. ⛴ ☆ - 61-----------------80
51+
. \ 81----------------100
52+
Cable ? -> \ 101---------------120
53+
Cable ? -> \ 121---------------140
54+
"""
55+
56+
A = auto()
57+
B = auto()
58+
C = auto()
59+
60+
61+
def analyze_streamer_headers(
62+
index_headers: npt.NDArray,
63+
) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray, StreamerShotGeometryType]:
64+
"""Check input headers for SEG-Y input to help determine geometry.
65+
66+
This function reads in trace_qc_count headers and finds the unique cable values.
67+
The function then checks to make sure channel numbers for different cables do
68+
not overlap.
69+
70+
Args:
71+
index_headers: numpy array with index headers
72+
73+
Returns:
74+
tuple of unique_cables, cable_chan_min, cable_chan_max, geom_type
75+
"""
76+
# Find unique cable ids
77+
unique_cables = np.sort(np.unique(index_headers["cable"]))
78+
79+
# Find channel min and max values for each cable
80+
cable_chan_min = np.empty(unique_cables.shape)
81+
cable_chan_max = np.empty(unique_cables.shape)
82+
83+
for idx, cable in enumerate(unique_cables):
84+
cable_mask = index_headers["cable"] == cable
85+
current_cable = index_headers["channel"][cable_mask]
86+
87+
cable_chan_min[idx] = np.min(current_cable)
88+
cable_chan_max[idx] = np.max(current_cable)
89+
90+
# Check channel numbers do not overlap for case B
91+
geom_type = StreamerShotGeometryType.B
92+
for idx, cable in enumerate(unique_cables):
93+
min_val = cable_chan_min[idx]
94+
max_val = cable_chan_max[idx]
95+
for idx2, cable2 in enumerate(unique_cables):
96+
if cable2 == cable:
97+
continue
98+
99+
if cable_chan_min[idx2] < max_val and cable_chan_max[idx2] > min_val:
100+
geom_type = StreamerShotGeometryType.A
101+
102+
return unique_cables, cable_chan_min, cable_chan_max, geom_type
103+
104+
105+
class GridOverrideCommand(ABC):
106+
"""Abstract base class for grid override commands."""
107+
108+
@property
109+
@abstractmethod
110+
def required_keys(self) -> set:
111+
"""Get the set of required keys for the grid override command."""
112+
113+
@abstractmethod
114+
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
115+
"""Validate if this transform should run on the type of data."""
116+
117+
@abstractmethod
118+
def transform(self, index_headers: npt.NDArray, **kwargs) -> npt.NDArray:
119+
"""Perform the grid transform."""
120+
121+
@property
122+
def name(self) -> str:
123+
"""Convenience property to get the name of the command."""
124+
return self.__class__.__name__
125+
126+
def check_required_keys(self, index_headers: npt.NDArray) -> None:
127+
"""Check if all required keys are present in the index headers."""
128+
index_names = index_headers.keys()
129+
if not self.required_keys.issubset(index_names):
130+
raise GridOverrideKeysError(self.name, self.required_keys)
131+
132+
133+
class AutoChannelWrap(GridOverrideCommand):
134+
"""Automatically determine Streamer acquisition type."""
135+
136+
required_keys = {"shot", "cable", "channel"}
137+
138+
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
139+
"""Validate if this transform should run on the type of data."""
140+
self.check_required_keys(index_headers)
141+
142+
if "ChannelWrap" in kwargs:
143+
raise GridOverrideIncompatibleError(self.name, "ChannelWrap")
144+
145+
if "CalculateCable" in kwargs:
146+
raise GridOverrideIncompatibleError(self.name, "CalculateCable")
147+
148+
def transform(self, index_headers: npt.NDArray, **kwargs):
149+
"""Perform the grid transform."""
150+
self.validate(index_headers, **kwargs)
151+
152+
result = analyze_streamer_headers(index_headers)
153+
unique_cables, cable_chan_min, cable_chan_max, geom_type = result
154+
logger.info(f"Ingesting dataset as {geom_type.name}")
155+
156+
# TODO: Add strict=True and remove noqa when min Python is 3.10
157+
for cable, chan_min, chan_max in zip( # noqa: B905
158+
unique_cables, cable_chan_min, cable_chan_max
159+
):
160+
logger.info(
161+
f"Cable: {cable} has min chan: {chan_min} and max chan: {chan_max}"
162+
)
163+
164+
# This might be slow and potentially could be improved with a rewrite
165+
# to prevent so many lookups
166+
if geom_type == StreamerShotGeometryType.B:
167+
for idx, cable in enumerate(unique_cables):
168+
cable_idxs = np.where(index_headers["cable"][:] == cable)
169+
cc_min = cable_chan_min[idx]
170+
171+
index_headers["channel"][cable_idxs] = (
172+
index_headers["channel"][cable_idxs] - cc_min + 1
173+
)
174+
175+
return index_headers
176+
177+
178+
class ChannelWrap(GridOverrideCommand):
179+
"""Wrap channels to start from one at cable boundaries."""
180+
181+
required_keys = {"shot", "cable", "channel"}
182+
183+
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
184+
"""Validate if this transform should run on the type of data."""
185+
self.check_required_keys(index_headers)
186+
187+
if "ChannelsPerCable" not in kwargs:
188+
raise GridOverrideMissingParameterError(self.name, "ChannelsPerCable")
189+
190+
if "AutoCableChannel" in kwargs:
191+
raise GridOverrideIncompatibleError(self.name, "AutoCableChannel")
192+
193+
def transform(self, index_headers: npt.NDArray, **kwargs) -> npt.NDArray:
194+
"""Perform the grid transform."""
195+
self.validate(index_headers, **kwargs)
196+
197+
channels_per_cable = kwargs["ChannelsPerCable"]
198+
index_headers["channel"] = (
199+
index_headers["channel"] - 1
200+
) % channels_per_cable + 1
201+
202+
return index_headers
203+
204+
205+
class CalculateCable(GridOverrideCommand):
206+
"""Calculate cable numbers from unwrapped channels."""
207+
208+
required_keys = {"shot", "cable", "channel"}
209+
210+
def validate(self, index_headers: npt.NDArray, **kwargs) -> None:
211+
"""Validate if this transform should run on the type of data."""
212+
self.check_required_keys(index_headers)
213+
214+
if "ChannelsPerCable" not in kwargs:
215+
raise GridOverrideMissingParameterError(self.name, "ChannelsPerCable")
216+
217+
if "AutoCableChannel" in kwargs:
218+
raise GridOverrideIncompatibleError(self.name, "AutoCableChannel")
219+
220+
def transform(self, index_headers, **kwargs):
221+
"""Perform the grid transform."""
222+
channels_per_cable = kwargs["ChannelsPerCable"]
223+
index_headers["cable"] = (
224+
index_headers["channel"] - 1
225+
) // channels_per_cable + 1
226+
227+
return index_headers
228+
229+
230+
class GridOverrider:
231+
"""Executor for grid overrides.
232+
233+
We support a certain type of grid overrides, and they have to be
234+
implemented following the ABC's in this module.
235+
236+
This class applies the grid overrides if needed.
237+
"""
238+
239+
def __init__(self):
240+
"""Define allowed overrides here."""
241+
self.commands = {
242+
"AutoChannelWrap": AutoChannelWrap(),
243+
"CalculateCable": CalculateCable(),
244+
"ChannelWrap": ChannelWrap(),
245+
}
246+
247+
def run(
248+
self,
249+
index_headers: npt.NDArray,
250+
grid_overrides: dict[str, bool],
251+
) -> npt.NDArray:
252+
"""Run grid overrides and return result."""
253+
for override in grid_overrides:
254+
if override in self.commands:
255+
function = self.commands[override].transform
256+
index_headers = function(index_headers, grid_overrides=grid_overrides)
257+
else:
258+
raise GridOverrideUnknownError(override)
259+
260+
return index_headers

0 commit comments

Comments
 (0)