33
44from __future__ import annotations
55
6+ from typing import Any
7+
68import numpy .typing as npt
79import pytest
810from numpy import arange
2325RECEIVERS = arange (1 , 6 , dtype = "int32" )
2426
2527
28+ def run_override (
29+ grid_overrides : dict [str , Any ],
30+ index_names : tuple [str , ...],
31+ headers : dict [str , npt .NDArray ],
32+ chunksize : tuple [int , ...] | None = None ,
33+ ) -> tuple [dict [str , Any ], tuple [str ], tuple [int ]]:
34+ """Initialize and run overrider."""
35+ overrider = GridOverrider ()
36+ return overrider .run (headers , index_names , grid_overrides , chunksize )
37+
38+
39+ def get_dims (headers : dict [str , npt .NDArray ]) -> list [Dimension ]:
40+ """Get list of Dimensions from headers."""
41+ dims = []
42+ for index_name , index_coords in headers .items ():
43+ dim_unique = unique (index_coords )
44+ dims .append (Dimension (coords = dim_unique , name = index_name ))
45+
46+ return dims
47+
48+
2649@pytest .fixture
2750def mock_streamer_headers () -> dict [str , npt .NDArray ]:
2851 """Generate dictionary of mocked streamer index headers."""
@@ -48,55 +71,49 @@ class TestAutoGridOverrides:
4871
4972 def test_duplicates (self , mock_streamer_headers : npt .NDArray ) -> None :
5073 """Test the HasDuplicates Grid Override command."""
74+ index_names = ("shot" , "cable" )
5175 grid_overrides = {"HasDuplicates" : True }
5276
53- # mock_streamer_headers["trace"] = mock_streamer_headers["channel"]
5477 # Remove channel header
5578 del mock_streamer_headers ["channel" ]
56- index_names = ("shot" , "cable" )
5779 chunksize = (4 , 4 , 8 )
5880
59- overrider = GridOverrider ()
60- results , new_names , new_chunks = overrider .run (
61- mock_streamer_headers , index_names , grid_overrides , chunksize
81+ new_headers , new_names , new_chunks = run_override (
82+ grid_overrides ,
83+ index_names ,
84+ mock_streamer_headers ,
85+ chunksize ,
6286 )
6387
6488 assert new_names == ("shot" , "cable" , "trace" )
6589 assert new_chunks == (4 , 4 , 1 , 8 )
6690
67- dims = []
68- for index_name , index_coords in results .items ():
69- dim_unique = unique (index_coords )
70- dims .append (Dimension (coords = dim_unique , name = index_name ))
91+ dims = get_dims (new_headers )
7192
7293 assert_array_equal (dims [0 ].coords , SHOTS )
7394 assert_array_equal (dims [1 ].coords , CABLES )
7495 assert_array_equal (dims [2 ].coords , RECEIVERS )
7596
7697 def test_non_binned (self , mock_streamer_headers : npt .NDArray ) -> None :
7798 """Test the NonBinned Grid Override command."""
99+ index_names = ("shot" , "cable" )
78100 grid_overrides = {"NonBinned" : True , "chunksize" : 4 }
79101
80102 # Remove channel header
81103 del mock_streamer_headers ["channel" ]
82- index_names = (
83- "shot" ,
84- "cable" ,
85- )
86104 chunksize = (4 , 4 , 8 )
87105
88- overrider = GridOverrider ()
89- results , new_names , new_chunks = overrider .run (
90- mock_streamer_headers , index_names , grid_overrides , chunksize
106+ new_headers , new_names , new_chunks = run_override (
107+ grid_overrides ,
108+ index_names ,
109+ mock_streamer_headers ,
110+ chunksize ,
91111 )
92112
93113 assert new_names == ("shot" , "cable" , "trace" )
94114 assert new_chunks == (4 , 4 , 4 , 8 )
95115
96- dims = []
97- for index_name , index_coords in results .items ():
98- dim_unique = unique (index_coords )
99- dims .append (Dimension (coords = dim_unique , name = index_name ))
116+ dims = get_dims (new_headers )
100117
101118 assert_array_equal (dims [0 ].coords , SHOTS )
102119 assert_array_equal (dims [1 ].coords , CABLES )
@@ -108,45 +125,35 @@ class TestStreamerGridOverrides:
108125
109126 def test_channel_wrap (self , mock_streamer_headers : npt .NDArray ) -> None :
110127 """Test the ChannelWrap command."""
111- grid_overrides = {"ChannelWrap" : True , "ChannelsPerCable" : len (RECEIVERS )}
112128 index_names = ("shot" , "cable" , "channel" )
113- chunksize = None
114- overrider = GridOverrider ()
115- results , new_names , new_chunks = overrider . run (
116- mock_streamer_headers , index_names , grid_overrides , chunksize
129+ grid_overrides = { "ChannelWrap" : True , "ChannelsPerCable" : len ( RECEIVERS )}
130+
131+ new_headers , new_names , new_chunks = run_override (
132+ grid_overrides , index_names , mock_streamer_headers
117133 )
118134
119135 assert new_names == index_names
120136 assert new_chunks is None
121- dims = []
122- for index_name , index_coords in results .items ():
123- dim_unique = unique (index_coords )
124- dims .append (Dimension (coords = dim_unique , name = index_name ))
125-
126- assert_array_equal (dims [0 ], SHOTS )
127- assert_array_equal (dims [1 ], CABLES )
128- assert_array_equal (dims [2 ], RECEIVERS )
137+
138+ dims = get_dims (new_headers )
139+
129140 assert_array_equal (dims [0 ].coords , SHOTS )
130141 assert_array_equal (dims [1 ].coords , CABLES )
131142 assert_array_equal (dims [2 ].coords , RECEIVERS )
132143
133144 def test_calculate_cable (self , mock_streamer_headers : npt .NDArray ) -> None :
134145 """Test the CalculateCable command."""
135- grid_overrides = {"CalculateCable" : True , "ChannelsPerCable" : len (RECEIVERS )}
136146 index_names = ("shot" , "cable" , "channel" )
137- chunksize = None
138- overrider = GridOverrider ()
139- results , new_names , new_chunks = overrider . run (
140- mock_streamer_headers , index_names , grid_overrides , chunksize
147+ grid_overrides = { "CalculateCable" : True , "ChannelsPerCable" : len ( RECEIVERS )}
148+
149+ new_headers , new_names , new_chunks = run_override (
150+ grid_overrides , index_names , mock_streamer_headers
141151 )
142152
143153 assert new_names == index_names
144154 assert new_chunks is None
145155
146- dims = []
147- for index_name , index_coords in results .items ():
148- dim_unique = unique (index_coords )
149- dims .append (Dimension (coords = dim_unique , name = index_name ))
156+ dims = get_dims (new_headers )
150157
151158 # We need channels because unwrap isn't done here
152159 channels = unique (mock_streamer_headers ["channel" ])
@@ -160,27 +167,21 @@ def test_calculate_cable(self, mock_streamer_headers: npt.NDArray) -> None:
160167
161168 def test_wrap_and_calc_cable (self , mock_streamer_headers : npt .NDArray ) -> None :
162169 """Test the combined ChannelWrap and CalculateCable commands."""
170+ index_names = ("shot" , "cable" , "channel" )
163171 grid_overrides = {
164172 "CalculateCable" : True ,
165173 "ChannelWrap" : True ,
166174 "ChannelsPerCable" : len (RECEIVERS ),
167175 }
168176
169- index_names = ("shot" , "cable" , "channel" )
170- chunksize = None
171- overrider = GridOverrider ()
172- results , new_names , new_chunks = overrider .run (
173- mock_streamer_headers , index_names , grid_overrides , chunksize
177+ new_headers , new_names , new_chunks = run_override (
178+ grid_overrides , index_names , mock_streamer_headers
174179 )
175180
176181 assert new_names == index_names
177182 assert new_chunks is None
178183
179- dims = []
180- for index_name , index_coords in results .items ():
181- dim_unique = unique (index_coords )
182- dims .append (Dimension (coords = dim_unique , name = index_name ))
183-
184+ dims = get_dims (new_headers )
184185 # We reset the cables to start from 1.
185186 cables = arange (1 , len (CABLES ) + 1 , dtype = "uint32" )
186187
0 commit comments