|
34 | 34 | import numpy as np |
35 | 35 |
|
36 | 36 | from pyeyesweb.data_models.sliding_window import SlidingWindow |
37 | | -from pyeyesweb.utils.signal_processing import bandpass_filter, compute_hilbert_phases |
| 37 | +from pyeyesweb.utils.signal_processing import bandpass_filter, compute_hilbert_phases, validate_filter_params |
38 | 38 | from pyeyesweb.utils.math_utils import compute_phase_locking_value, center_signals |
| 39 | +from pyeyesweb.utils.validators import validate_integer, validate_boolean, validate_numeric |
39 | 40 |
|
40 | 41 |
|
41 | 42 | class Synchronization: |
@@ -125,45 +126,31 @@ class Synchronization: |
125 | 126 | """ |
126 | 127 |
|
127 | 128 | def __init__(self, sensitivity=100, output_phase=False, filter_params=None, phase_threshold=0.7): |
128 | | - # Validate sensitivity |
129 | | - if not isinstance(sensitivity, int): |
130 | | - raise TypeError(f"sensitivity must be an integer, got {type(sensitivity).__name__}") |
131 | | - if sensitivity <= 0: |
132 | | - raise ValueError(f"sensitivity must be positive, got {sensitivity}") |
133 | | - if sensitivity > 10000: # Reasonable upper limit |
134 | | - raise ValueError(f"sensitivity too large ({sensitivity}), maximum is 10,000") |
135 | | - |
136 | | - # Validate output_phase |
137 | | - if not isinstance(output_phase, bool): |
138 | | - raise TypeError(f"output_phase must be boolean, got {type(output_phase).__name__}") |
139 | | - |
140 | | - # Validate phase_threshold |
141 | | - if not isinstance(phase_threshold, (int, float)): |
142 | | - raise TypeError(f"phase_threshold must be a number, got {type(phase_threshold).__name__}") |
143 | | - if not 0 <= phase_threshold <= 1: |
144 | | - raise ValueError(f"phase_threshold must be between 0 and 1, got {phase_threshold}") |
| 129 | + # Validate sensitivity using centralized validator |
| 130 | + sensitivity = validate_integer(sensitivity, 'sensitivity', min_val=1, max_val=10000) |
| 131 | + |
| 132 | + # Validate output_phase using centralized validator |
| 133 | + self.output_phase = validate_boolean(output_phase, 'output_phase') |
| 134 | + |
| 135 | + # Validate phase_threshold using centralized validator with range check |
| 136 | + self.phase_threshold = validate_numeric(phase_threshold, 'phase_threshold', min_val=0, max_val=1) |
145 | 137 |
|
146 | 138 | # Validate filter_params if provided |
147 | 139 | if filter_params is not None: |
148 | 140 | if not isinstance(filter_params, (tuple, list)): |
149 | 141 | raise TypeError(f"filter_params must be a tuple or list, got {type(filter_params).__name__}") |
150 | 142 | if len(filter_params) != 3: |
151 | 143 | raise ValueError(f"filter_params must have 3 elements (lowcut, highcut, fs), got {len(filter_params)}") |
152 | | - lowcut, highcut, fs = filter_params |
153 | 144 | if not all(isinstance(x, (int, float)) for x in filter_params): |
154 | 145 | raise TypeError("filter_params must contain only numbers") |
155 | | - if lowcut <= 0 or highcut <= 0 or fs <= 0: |
156 | | - raise ValueError("filter_params frequencies must be positive") |
157 | | - if lowcut >= highcut: |
158 | | - raise ValueError(f"lowcut ({lowcut}) must be less than highcut ({highcut})") |
159 | | - if highcut >= fs / 2: |
160 | | - raise ValueError(f"highcut ({highcut}) must be less than Nyquist frequency ({fs/2})") |
| 146 | + # Use centralized filter validation |
| 147 | + lowcut, highcut, fs = validate_filter_params(*filter_params) |
| 148 | + # Store the validated parameters |
| 149 | + filter_params = (lowcut, highcut, fs) |
161 | 150 |
|
162 | 151 | self.plv_history = deque(maxlen=sensitivity) |
163 | 152 | self._history_lock = threading.Lock() |
164 | | - self.output_phase = output_phase |
165 | 153 | self.filter_params = filter_params |
166 | | - self.phase_threshold = phase_threshold |
167 | 154 |
|
168 | 155 |
|
169 | 156 | def compute_synchronization(self, signals: SlidingWindow): |
|
0 commit comments