Skip to content

Commit 19b6b58

Browse files
committed
Attempt to address #1
1 parent 8fa4645 commit 19b6b58

3 files changed

Lines changed: 138 additions & 89 deletions

File tree

pyxdf/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
#
44
# License: BSD (2-clause)
55

# Expose the installed package version as __version__, falling back to None
# when the package is not installed or no metadata backend is available.
try:
    # Preferred: stdlib metadata lookup (Python 3.8+), no setuptools needed.
    from importlib.metadata import PackageNotFoundError, version

    try:
        __version__ = version(__name__)
    except PackageNotFoundError:  # package is not installed
        __version__ = None
except ImportError:  # older interpreter: fall back to pkg_resources
    try:
        from pkg_resources import get_distribution, DistributionNotFound

        try:
            __version__ = get_distribution(__name__).version
        except DistributionNotFound:  # package is not installed
            __version__ = None
    except ImportError:  # pkg_resources is not available either
        __version__ = None
1115
from .pyxdf import load_xdf, resolve_streams, match_streaminfos, align_streams
1216

pyxdf/align.py

Lines changed: 116 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -31,148 +31,178 @@ def _interpolate(
3131

3232

3333
def _shift_align(old_timestamps, old_timeseries, new_timestamps):
34+
# Convert inputs to numpy arrays
3435
old_timestamps = np.array(old_timestamps)
3536
old_timeseries = np.array(old_timeseries)
3637
new_timestamps = np.array(new_timestamps)
38+
3739
ts_last = old_timestamps[-1]
38-
ts_first = old_timestamps[0]
39-
source = list()
40-
target = list()
41-
new_timeseries = np.empty((
42-
new_timestamps.shape[0], # new sample count
43-
old_timeseries.shape[1], # old channel count
44-
), dtype=object)
45-
new_timeseries.fill(np.nan)
46-
too_old = list()
47-
too_young = list()
40+
ts_first = old_timestamps[0]
41+
42+
# Initialize variables
43+
source = []
44+
target = []
45+
46+
new_timeseries = np.full((new_timestamps.shape[0], old_timeseries.shape[1]), np.nan)
47+
48+
too_old = []
49+
too_young = []
50+
51+
# Loop through new timestamps to find the closest old timestamp
52+
# Handle timestamps outside of the segment (too young or too old) different from stamnps from within the segment
4853
for nix, nts in enumerate(new_timestamps):
49-
closest = (np.abs(old_timestamps - nts)).argmin()
50-
# remember the edge cases,
51-
if (nts>ts_last):
54+
if nts > ts_last:
5255
too_young.append((nix, nts))
53-
elif (nts < ts_first):
54-
too_old.append((nix,nts))
56+
elif nts < ts_first:
57+
too_old.append((nix, nts))
5558
else:
56-
closest = (np.abs(old_timestamps - nts)).argmin()
57-
source.append(closest)
58-
target.append(nix)
59-
# check the edge cases,
60-
for nix, nts in reversed(too_old):
61-
closest = (np.abs(old_timestamps - nts)).argmin()
62-
if (closest not in source):
59+
closest = np.abs(old_timestamps - nts).argmin()
60+
if closest not in source: # Ensure unique mapping
61+
source.append(closest)
62+
target.append(nix)
63+
else:
64+
raise RuntimeError(
65+
f"Non-unique mapping. Closest old timestamp for {new_timestamps[nix]} is {old_timestamps[closest]} but that one was already assigned to {new_timestamps[source.index(closest)]}"
66+
)
67+
68+
# Handle too old timestamps (those before the first old timestamp)
69+
for nix, nts in too_old:
70+
closest = 0 # Assign to the first timestamp
71+
if closest not in source: # Ensure unique mapping
6372
source.append(closest)
6473
target.append(nix)
65-
break
74+
break # only one, because we only need the edge
75+
76+
# Handle too young timestamps (those after the last old timestamp)
6677
for nix, nts in too_young:
67-
closest = (np.abs(old_timestamps - nts)).argmin()
68-
if (closest not in source):
78+
closest = len(old_timestamps) - 1 # Assign to the last timestamp
79+
if closest not in source: # Ensure unique mapping
6980
source.append(closest)
7081
target.append(nix)
71-
break
72-
73-
if len(set(source)) != len(old_timestamps):
74-
missed = len(old_timestamps)-len(set(source))
75-
raise RuntimeError(f"Too few new timestamps. {missed} of {len(old_timestamps)} old samples could not be assigned.")
76-
if len(set(source)) != len(source): #non-unique mapping
77-
cnt = Counter(source)
78-
toomany = defaultdict(list)
79-
for v,n in zip(source, target):
80-
if cnt[v] != 1:
81-
toomany[old_timestamps[source[v]]].append(new_timestamps[target[n]])
82-
for k,v in toomany.items():
83-
print("The old time_stamp ", k,
84-
"is a closest neighbor of", len(v) ,"new time_stamps:", v)
85-
raise RuntimeError("Can not align streams. Could not create an unique mapping")
82+
break # only one, because we only need the edge
83+
84+
# Sanity check: all old timestamps should be assigned to at least one new timestamp
85+
missed = len(old_timestamps) - len(set(source))
86+
if missed > 0:
87+
unassigned_old = [i for i in range(len(old_timestamps)) if i not in source]
88+
raise RuntimeError(
89+
f"Too few new timestamps. {missed} old timestamps ({old_timestamps[unassigned_old]}) found no corresponding new timestamp because it was already taken by another old timestamp. If your stream has multiple segments, this might be caused by small differences in effective srate between segments. Try different dejittering thresholds or support your own aligned_timestamps."
90+
)
91+
92+
# Populate new timeseries with aligned values from old_timeseries
8693
for chan in range(old_timeseries.shape[1]):
87-
new_timeseries[target, chan] = old_timeseries[source,chan]
94+
new_timeseries[target, chan] = old_timeseries[source, chan]
95+
8896
return new_timeseries
8997

9098

91-
def align_streams(streams, # List[defaultdict]
92-
align_foo=dict(), # defaultdict[int, Callable]
93-
aligned_timestamps=None, # Optional[List[float]]
94-
sampling_rate=None # Optional[float|int]
95-
): # -> Tuple[np.ndarray, List[float]]
99+
def align_streams(
100+
streams, # List[defaultdict]
101+
align_foo=dict(), # defaultdict[int, Callable]
102+
aligned_timestamps=None, # Optional[List[float]]
103+
sampling_rate=None, # Optional[float|int]
104+
): # -> Tuple[np.ndarray, List[float]]
96105
"""
97-
A function to
106+
A function to
98107
99108
100109
Args:
101110
102-
streams: a list of defaultdicts (i.e. streams) as returned by
111+
streams: a list of defaultdicts (i.e. streams) as returned by
103112
load_xdf
104-
align_foo: a dictionary mapping streamIDs (i.e. int) to interpolation
105-
callables. These callables must have the signature
113+
align_foo: a dictionary mapping streamIDs (i.e. int) to interpolation
114+
callables. These callables must have the signature
106115
`interpolate(old_timestamps, old_timeseries, new_timestamps)` and return a np.ndarray. See `_shift_align` and `_interpolate` for examples.
107-
aligned_timestamps (optional): a list of floats with the new
116+
aligned_timestamps (optional): a list of floats with the new
108117
timestamps to be used for alignment/interpolation. This list of timestamps can be irregular and have gaps.
109-
sampling_rate (optional): a float defining the sampling rate which
118+
sampling_rate (optional): a float defining the sampling rate which
110119
will be used to calculate aligned_timestamps.
111-
120+
112121
Return:
113122
(aligned_timeseries, aligned_timestamps): tuple
114123
115124
116-
THe user can define either aligned_timestamps or sampling_rate or neither. If neither is defined, the algorithm will take the sampling_rate of the fastest stream and create aligned_timestamps from the oldest sample of all streams to the youngest.
117-
125+
The user can define either aligned_timestamps or sampling_rate or neither. If neither is defined, the algorithm will take the sampling_rate of the fastest stream and create aligned_timestamps from the oldest sample of all streams to the youngest.
126+
118127
"""
119-
128+
120129
if sampling_rate is not None and aligned_timestamps is not None:
121-
raise ValueError("You can not specify aligned_timestamps and sampling_rate at the same time")
122-
130+
raise ValueError(
131+
"You can not specify aligned_timestamps and sampling_rate at the same time"
132+
)
133+
123134
if sampling_rate is None:
124-
# we pick the effective sampling rate from the fastest stream
135+
# we pick the effective sampling rate from the fastest stream
125136
srates = [stream["info"]["effective_srate"] for stream in streams]
126137
sampling_rate = max(srates, default=0)
127138
if sampling_rate <= 0: # either no valid stream or all streams are async
128-
warnings.warn("Can not align streams: Fastest effective sampling rate was 0 or smaller.")
139+
warnings.warn(
140+
"Can not align streams: Fastest effective sampling rate was 0 or smaller."
141+
)
129142
return streams
130-
131-
132-
if aligned_timestamps is None:
143+
144+
if aligned_timestamps is None:
133145
# we pick the oldest and youngest timestamp of all streams
134-
stamps = [stream["time_stamps"] for stream in streams]
135-
ts_first = min((min(s) for s in stamps))
136-
ts_last = max((max(s) for s in stamps))
137-
full_dur = ts_last-ts_first
138-
step = 1/sampling_rate
146+
stamps = [stream["time_stamps"] for stream in streams]
147+
ts_first = min((min(s) for s in stamps))
148+
ts_last = max((max(s) for s in stamps))
149+
full_dur = ts_last - ts_first
150+
# Use np.linspace for precise control over the number of points and guaranteed inclusion of the stop value.
151+
# np.arange is better when you need direct control over step size but may exclude the stop value and accumulate floating-point errors.
152+
# Choose np.linspace for better precision and np.arange for efficiency with fixed steps.
139153
# we create new regularized timestamps
140-
aligned_timestamps = np.arange(ts_first, ts_last+step/2, step)
141-
# using np.linspace only differs in step if n_samples is different (as n_samples must be an integer number (see implementation below).
142-
# therefore we stick with np.arange (in spite of possible floating point error accumulation, but to make sure that ts_last is included, we add a half-step. This therefore comes at the cost of a overshoot, but i consider this acceptable considering this stamp would only be from one stream, and not part of all other and therefore is kind of arbitray anyways.
154+
# arange implementation:
155+
# step = 1 / sampling_rate
156+
# aligned_timestamps = np.arange(ts_first, ts_last + step / 2, step)
143157
# linspace implementation:
144-
# n_samples = int(np.round((full_dur * sampling_rate),0))+1
145-
# aligned_timestamps = np.linspace(ts_first, ts_last, n_samples)
146-
158+
# add 1 to the number of samples to include the last sample
159+
n_samples = int(np.round((full_dur * sampling_rate), 0)) + 1
160+
aligned_timestamps = np.linspace(ts_first, ts_last, n_samples)
161+
147162
channels = 0
148163
for stream in streams:
149164
# print(stream)
150165
channels += int(stream["info"]["channel_count"][0])
151166
# https://stackoverflow.com/questions/1704823/create-numpy-matrix-filled-with-nans The timings show a preference for ndarray.fill(..) as the faster alternative.
152-
aligned_timeseries = np.empty((len(aligned_timestamps),
153-
channels,), dtype=object)
167+
aligned_timeseries = np.empty(
168+
(
169+
len(aligned_timestamps),
170+
channels,
171+
),
172+
dtype=object,
173+
)
154174
aligned_timeseries.fill(np.nan)
155175

156-
chan_start = 0
176+
chan_start = 0
157177
chan_end = 0
158178
for stream in streams:
159179
sid = stream["info"]["stream_id"]
160-
align = align_foo.get(sid, _shift_align)
180+
align = align_foo.get(sid, _shift_align)
161181
chan_cnt = int(stream["info"]["channel_count"][0])
162182
new_timeseries = np.empty((len(aligned_timestamps), chan_cnt), dtype=object)
163183
new_timeseries.fill(np.nan)
164-
for seg_start, seg_stop in stream["info"]["segments"]:
165-
_new_timeseries = align(
166-
stream["time_stamps"][seg_start:seg_stop+1],
167-
stream["time_series"][seg_start:seg_stop+1],
168-
aligned_timestamps)
184+
print("Stream #", sid, " has ", len(stream["info"]["segments"]), "segments")
185+
for seg_idx, (seg_start, seg_stop) in enumerate(stream["info"]["segments"]):
186+
print(seg_idx, ": from index ", seg_start, "to ", seg_stop + 1)
187+
# segments have been created including the stop index, so we need to add 1 to include the last sample
188+
segment_old_timestamps = stream["time_stamps"][seg_start : seg_stop + 1]
189+
segment_old_timeseries = stream["time_series"][seg_start : seg_stop + 1]
190+
# Sanity check for duplicate timestamps
191+
if len(np.unique(segment_old_timestamps)) != len(segment_old_timestamps):
192+
raise RuntimeError("Duplicate timestamps found in old_timestamps")
193+
# apply align function as defined by the user (or default)
194+
segment_new_timeseries = align(
195+
segment_old_timestamps,
196+
segment_old_timeseries,
197+
aligned_timestamps,
198+
)
169199
# pick indices of the NEW timestamps closest to when segments start and stop
170200
a = stream["time_stamps"][seg_start]
171201
b = stream["time_stamps"][seg_stop]
172-
aix = np.argmin(np.abs(aligned_timestamps-a))
173-
bix = np.argmin(np.abs(aligned_timestamps-b))
202+
aix = np.argmin(np.abs(aligned_timestamps - a))
203+
bix = np.argmin(np.abs(aligned_timestamps - b))
174204
# and store only this aligned segment, leaving the rest as nans (or aligned as other segments)
175-
new_timeseries[aix:bix+1] = _new_timeseries[aix:bix+1]
205+
new_timeseries[aix : bix + 1] = segment_new_timeseries[aix : bix + 1]
176206

177207
# store the new timeseries at the respective channel indices in the 2D array
178208
chan_start = chan_end

test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import matplotlib.pyplot as plt
2+
import pyxdf
3+
4+
if __name__ == "__main__":
5+
fname = "/home/rtgugg/Downloads/sub-13_ses-S001_task-HCT_run-001_eeg.xdf"
6+
# streams, header = pyxdf.load_xdf(
7+
# fname, select_streams=[2, 5]
8+
# ) # EEG and ACC streams
9+
10+
# pyxdf.align_streams(streams)
11+
12+
streams, header = pyxdf.load_xdf(fname, select_streams=[2]) # EEG stream
13+
plt.plot(streams[0]["time_stamps"])
14+
plt.show()
15+
pyxdf.align_streams(streams)

0 commit comments

Comments
 (0)