Skip to content

Commit d1ea34c

Browse files
Refactor modules to use centralized validators
1 parent e700afc commit d1ea34c

3 files changed

Lines changed: 26 additions & 62 deletions

File tree

pyeyesweb/data_models/sliding_window.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
import numpy as np
44
from typing import Optional, Union
5+
from pyeyesweb.utils.validators import validate_integer
56

67

78
class SlidingWindow:
@@ -59,30 +60,14 @@ def max_length(self, value: int):
5960
self._resize(old_max_length)
6061

6162
def __init__(self, max_length: int, n_columns: int):
62-
# Validate inputs
63-
if not isinstance(max_length, int):
64-
raise TypeError(f"max_length must be an integer, got {type(max_length).__name__}")
65-
if not isinstance(n_columns, int):
66-
raise TypeError(f"n_columns must be an integer, got {type(n_columns).__name__}")
67-
68-
if max_length <= 0:
69-
raise ValueError(f"max_length must be positive, got {max_length}")
70-
if n_columns <= 0:
71-
raise ValueError(f"n_columns must be positive, got {n_columns}")
72-
73-
# Reasonable limits to prevent memory exhaustion
74-
if max_length > 10_000_000: # Have added 10 million samples
75-
raise ValueError(f"max_length too large ({max_length}), maximum is 10,000,000")
76-
if n_columns > 10_000: # 10k features
77-
raise ValueError(f"n_columns too large ({n_columns}), maximum is 10,000")
63+
# Validate inputs using centralized validators
64+
self._max_length = validate_integer(max_length, 'max_length', min_val=1, max_val=10_000_000)
65+
self._n_columns = validate_integer(n_columns, 'n_columns', min_val=1, max_val=10_000)
7866

7967
self._lock = threading.RLock()
8068

81-
self._max_length = max_length
82-
self._n_columns = n_columns
83-
84-
self._buffer = np.empty((max_length, n_columns), dtype=np.float32)
85-
self._timestamp = np.empty(max_length, dtype=np.float64)
69+
self._buffer = np.empty((self._max_length, self._n_columns), dtype=np.float32)
70+
self._timestamp = np.empty(self._max_length, dtype=np.float64)
8671

8772
self._start = 0
8873
self._size = 0

pyeyesweb/mid_level/smoothness.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pyeyesweb.data_models.sliding_window import SlidingWindow
1717
from pyeyesweb.utils.signal_processing import apply_savgol_filter
1818
from pyeyesweb.utils.math_utils import compute_sparc, compute_jerk_rms, normalize_signal
19+
from pyeyesweb.utils.validators import validate_numeric, validate_boolean
1920

2021

2122
class Smoothness:
@@ -64,20 +65,11 @@ class Smoothness:
6465
"""
6566

6667
def __init__(self, rate_hz=50.0, use_filter=True):
67-
# Validate rate_hz
68-
if not isinstance(rate_hz, (int, float)):
69-
raise TypeError(f"rate_hz must be a number, got {type(rate_hz).__name__}")
70-
if rate_hz <= 0:
71-
raise ValueError(f"rate_hz must be positive, got {rate_hz}")
72-
if rate_hz > 100000: # 100 kHz is a reasonable upper limit
73-
raise ValueError(f"rate_hz too high ({rate_hz} Hz), maximum is 100,000 Hz")
74-
75-
# Validate use_filter
76-
if not isinstance(use_filter, bool):
77-
raise TypeError(f"use_filter must be boolean, got {type(use_filter).__name__}")
78-
79-
self.rate_hz = float(rate_hz) # Ensure it's a float
80-
self.use_filter = use_filter
68+
# Validate rate_hz using centralized validator
69+
self.rate_hz = validate_numeric(rate_hz, 'rate_hz', min_val=0.01, max_val=100000)
70+
71+
# Validate use_filter using centralized validator
72+
self.use_filter = validate_boolean(use_filter, 'use_filter')
8173

8274
def _filter_signal(self, signal):
8375
"""Apply Savitzky-Golay filter if enabled and enough data.

pyeyesweb/sync.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@
3434
import numpy as np
3535

3636
from pyeyesweb.data_models.sliding_window import SlidingWindow
37-
from pyeyesweb.utils.signal_processing import bandpass_filter, compute_hilbert_phases
37+
from pyeyesweb.utils.signal_processing import bandpass_filter, compute_hilbert_phases, validate_filter_params
3838
from pyeyesweb.utils.math_utils import compute_phase_locking_value, center_signals
39+
from pyeyesweb.utils.validators import validate_integer, validate_boolean, validate_numeric
3940

4041

4142
class Synchronization:
@@ -125,45 +126,31 @@ class Synchronization:
125126
"""
126127

127128
def __init__(self, sensitivity=100, output_phase=False, filter_params=None, phase_threshold=0.7):
128-
# Validate sensitivity
129-
if not isinstance(sensitivity, int):
130-
raise TypeError(f"sensitivity must be an integer, got {type(sensitivity).__name__}")
131-
if sensitivity <= 0:
132-
raise ValueError(f"sensitivity must be positive, got {sensitivity}")
133-
if sensitivity > 10000: # Reasonable upper limit
134-
raise ValueError(f"sensitivity too large ({sensitivity}), maximum is 10,000")
135-
136-
# Validate output_phase
137-
if not isinstance(output_phase, bool):
138-
raise TypeError(f"output_phase must be boolean, got {type(output_phase).__name__}")
139-
140-
# Validate phase_threshold
141-
if not isinstance(phase_threshold, (int, float)):
142-
raise TypeError(f"phase_threshold must be a number, got {type(phase_threshold).__name__}")
143-
if not 0 <= phase_threshold <= 1:
144-
raise ValueError(f"phase_threshold must be between 0 and 1, got {phase_threshold}")
129+
# Validate sensitivity using centralized validator
130+
sensitivity = validate_integer(sensitivity, 'sensitivity', min_val=1, max_val=10000)
131+
132+
# Validate output_phase using centralized validator
133+
self.output_phase = validate_boolean(output_phase, 'output_phase')
134+
135+
# Validate phase_threshold using centralized validator with range check
136+
self.phase_threshold = validate_numeric(phase_threshold, 'phase_threshold', min_val=0, max_val=1)
145137

146138
# Validate filter_params if provided
147139
if filter_params is not None:
148140
if not isinstance(filter_params, (tuple, list)):
149141
raise TypeError(f"filter_params must be a tuple or list, got {type(filter_params).__name__}")
150142
if len(filter_params) != 3:
151143
raise ValueError(f"filter_params must have 3 elements (lowcut, highcut, fs), got {len(filter_params)}")
152-
lowcut, highcut, fs = filter_params
153144
if not all(isinstance(x, (int, float)) for x in filter_params):
154145
raise TypeError("filter_params must contain only numbers")
155-
if lowcut <= 0 or highcut <= 0 or fs <= 0:
156-
raise ValueError("filter_params frequencies must be positive")
157-
if lowcut >= highcut:
158-
raise ValueError(f"lowcut ({lowcut}) must be less than highcut ({highcut})")
159-
if highcut >= fs / 2:
160-
raise ValueError(f"highcut ({highcut}) must be less than Nyquist frequency ({fs/2})")
146+
# Use centralized filter validation
147+
lowcut, highcut, fs = validate_filter_params(*filter_params)
148+
# Store the validated parameters
149+
filter_params = (lowcut, highcut, fs)
161150

162151
self.plv_history = deque(maxlen=sensitivity)
163152
self._history_lock = threading.Lock()
164-
self.output_phase = output_phase
165153
self.filter_params = filter_params
166-
self.phase_threshold = phase_threshold
167154

168155

169156
def compute_synchronization(self, signals: SlidingWindow):

0 commit comments

Comments
 (0)