Skip to content

Commit 3d73e6d

Browse files
committed
Attempt to address xdf-modules#1
1 parent 8fa4645 commit 3d73e6d

3 files changed

Lines changed: 142 additions & 89 deletions

File tree

pyxdf/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
#
44
# License: BSD (2-clause)
55

6-
from pkg_resources import get_distribution, DistributionNotFound
76
try:
8-
__version__ = get_distribution(__name__).version
9-
except DistributionNotFound: # package is not installed
7+
from pkg_resources import get_distribution, DistributionNotFound
8+
9+
try:
10+
__version__ = get_distribution(__name__).version
11+
except DistributionNotFound: # package is not installed
12+
__version__ = None
13+
except ImportError: # pkg_resources is not available
1014
__version__ = None
1115
from .pyxdf import load_xdf, resolve_streams, match_streaminfos, align_streams
1216

pyxdf/align.py

Lines changed: 120 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -31,148 +31,182 @@ def _interpolate(
3131

3232

3333
def _shift_align(old_timestamps, old_timeseries, new_timestamps):
34+
# Convert inputs to numpy arrays
3435
old_timestamps = np.array(old_timestamps)
3536
old_timeseries = np.array(old_timeseries)
3637
new_timestamps = np.array(new_timestamps)
38+
3739
ts_last = old_timestamps[-1]
38-
ts_first = old_timestamps[0]
39-
source = list()
40-
target = list()
41-
new_timeseries = np.empty((
42-
new_timestamps.shape[0], # new sample count
43-
old_timeseries.shape[1], # old channel count
44-
), dtype=object)
45-
new_timeseries.fill(np.nan)
46-
too_old = list()
47-
too_young = list()
40+
ts_first = old_timestamps[0]
41+
42+
# Initialize variables
43+
source = []
44+
target = []
45+
46+
new_timeseries = np.full((new_timestamps.shape[0], old_timeseries.shape[1]), np.nan)
47+
48+
too_old = []
49+
too_young = []
50+
51+
# Loop through new timestamps to find the closest old timestamp
52+
# Handle timestamps outside of the segment (too young or too old) different from stamnps from within the segment
4853
for nix, nts in enumerate(new_timestamps):
49-
closest = (np.abs(old_timestamps - nts)).argmin()
50-
# remember the edge cases,
51-
if (nts>ts_last):
54+
if nts > ts_last:
5255
too_young.append((nix, nts))
53-
elif (nts < ts_first):
54-
too_old.append((nix,nts))
56+
elif nts < ts_first:
57+
too_old.append((nix, nts))
5558
else:
56-
closest = (np.abs(old_timestamps - nts)).argmin()
57-
source.append(closest)
58-
target.append(nix)
59-
# check the edge cases,
60-
for nix, nts in reversed(too_old):
61-
closest = (np.abs(old_timestamps - nts)).argmin()
62-
if (closest not in source):
59+
closest = np.abs(old_timestamps - nts).argmin()
60+
if closest not in source: # Ensure unique mapping
61+
source.append(closest)
62+
target.append(nix)
63+
else:
64+
raise RuntimeError(
65+
f"Non-unique mapping. Closest old timestamp for {new_timestamps[nix]} is {old_timestamps[closest]} but that one was already assigned to {new_timestamps[source.index(closest)]}"
66+
)
67+
68+
# Handle too old timestamps (those before the first old timestamp)
69+
for nix, nts in too_old:
70+
closest = 0 # Assign to the first timestamp
71+
if closest not in source: # Ensure unique mapping
6372
source.append(closest)
6473
target.append(nix)
65-
break
74+
break # only one, because we only need the edge
75+
76+
# Handle too young timestamps (those after the last old timestamp)
6677
for nix, nts in too_young:
67-
closest = (np.abs(old_timestamps - nts)).argmin()
68-
if (closest not in source):
78+
closest = len(old_timestamps) - 1 # Assign to the last timestamp
79+
if closest not in source: # Ensure unique mapping
6980
source.append(closest)
7081
target.append(nix)
71-
break
72-
73-
if len(set(source)) != len(old_timestamps):
74-
missed = len(old_timestamps)-len(set(source))
75-
raise RuntimeError(f"Too few new timestamps. {missed} of {len(old_timestamps)} old samples could not be assigned.")
76-
if len(set(source)) != len(source): #non-unique mapping
77-
cnt = Counter(source)
78-
toomany = defaultdict(list)
79-
for v,n in zip(source, target):
80-
if cnt[v] != 1:
81-
toomany[old_timestamps[source[v]]].append(new_timestamps[target[n]])
82-
for k,v in toomany.items():
83-
print("The old time_stamp ", k,
84-
"is a closest neighbor of", len(v) ,"new time_stamps:", v)
85-
raise RuntimeError("Can not align streams. Could not create an unique mapping")
82+
break # only one, because we only need the edge
83+
84+
# Sanity check: all old timestamps should be assigned to at least one new timestamp
85+
missed = len(old_timestamps) - len(set(source))
86+
if missed > 0:
87+
unassigned_old = [i for i in range(len(old_timestamps)) if i not in source]
88+
for i, ts in zip(unassigned_old, old_timestamps[unassigned_old]):
89+
print(
90+
f"Old timestamp {ts} was not assigned to any new timestamp. Closest new timestamp is {new_timestamps[np.abs(new_timestamps - ts).argmin()]}"
91+
)
92+
raise RuntimeError(
93+
f"Too few new timestamps. {missed} old timestamps ({old_timestamps[unassigned_old]}) found no corresponding new timestamp because it was already taken by another old timestamp."
94+
)
95+
96+
# Populate new timeseries with aligned values from old_timeseries
8697
for chan in range(old_timeseries.shape[1]):
87-
new_timeseries[target, chan] = old_timeseries[source,chan]
98+
new_timeseries[target, chan] = old_timeseries[source, chan]
99+
88100
return new_timeseries
89101

90102

91-
def align_streams(streams, # List[defaultdict]
92-
align_foo=dict(), # defaultdict[int, Callable]
93-
aligned_timestamps=None, # Optional[List[float]]
94-
sampling_rate=None # Optional[float|int]
95-
): # -> Tuple[np.ndarray, List[float]]
103+
def align_streams(
104+
streams, # List[defaultdict]
105+
align_foo=dict(), # defaultdict[int, Callable]
106+
aligned_timestamps=None, # Optional[List[float]]
107+
sampling_rate=None, # Optional[float|int]
108+
): # -> Tuple[np.ndarray, List[float]]
96109
"""
97-
A function to
110+
A function to
98111
99112
100113
Args:
101114
102-
streams: a list of defaultdicts (i.e. streams) as returned by
115+
streams: a list of defaultdicts (i.e. streams) as returned by
103116
load_xdf
104-
align_foo: a dictionary mapping streamIDs (i.e. int) to interpolation
105-
callables. These callables must have the signature
117+
align_foo: a dictionary mapping streamIDs (i.e. int) to interpolation
118+
callables. These callables must have the signature
106119
`interpolate(old_timestamps, old_timeseries, new_timestamps)` and return a np.ndarray. See `_shift_align` and `_interpolate` for examples.
107-
aligned_timestamps (optional): a list of floats with the new
120+
aligned_timestamps (optional): a list of floats with the new
108121
timestamps to be used for alignment/interpolation. This list of timestamps can be irregular and have gaps.
109-
sampling_rate (optional): a float defining the sampling rate which
122+
sampling_rate (optional): a float defining the sampling rate which
110123
will be used to calculate aligned_timestamps.
111-
124+
112125
Return:
113126
(aligned_timeseries, aligned_timestamps): tuple
114127
115128
116-
THe user can define either aligned_timestamps or sampling_rate or neither. If neither is defined, the algorithm will take the sampling_rate of the fastest stream and create aligned_timestamps from the oldest sample of all streams to the youngest.
117-
129+
THe user can define either aligned_timestamps or sampling_rate or neither. If neither is defined, the algorithm will take the sampling_rate of the fastest stream and create aligned_timestamps from the oldest sample of all streams to the youngest.
130+
118131
"""
119-
132+
120133
if sampling_rate is not None and aligned_timestamps is not None:
121-
raise ValueError("You can not specify aligned_timestamps and sampling_rate at the same time")
122-
134+
raise ValueError(
135+
"You can not specify aligned_timestamps and sampling_rate at the same time"
136+
)
137+
123138
if sampling_rate is None:
124-
# we pick the effective sampling rate from the fastest stream
139+
# we pick the effective sampling rate from the fastest stream
125140
srates = [stream["info"]["effective_srate"] for stream in streams]
126141
sampling_rate = max(srates, default=0)
127142
if sampling_rate <= 0: # either no valid stream or all streams are async
128-
warnings.warn("Can not align streams: Fastest effective sampling rate was 0 or smaller.")
143+
warnings.warn(
144+
"Can not align streams: Fastest effective sampling rate was 0 step = 1 / sampling_rateor smaller."
145+
)
129146
return streams
130-
131-
132-
if aligned_timestamps is None:
147+
148+
if aligned_timestamps is None:
133149
# we pick the oldest and youngest timestamp of all streams
134-
stamps = [stream["time_stamps"] for stream in streams]
135-
ts_first = min((min(s) for s in stamps))
136-
ts_last = max((max(s) for s in stamps))
137-
full_dur = ts_last-ts_first
138-
step = 1/sampling_rate
150+
stamps = [stream["time_stamps"] for stream in streams]
151+
ts_first = min((min(s) for s in stamps))
152+
ts_last = max((max(s) for s in stamps))
153+
full_dur = ts_last - ts_first
154+
# Use np.linspace for precise control over the number of points and guaranteed inclusion of the stop value.
155+
# np.arange is better when you need direct control over step size but may exclude the stop value and accumulate floating-point errors.
156+
# Choose np.linspace for better precision and np.arange for efficiency with fixed steps.
139157
# we create new regularized timestamps
140-
aligned_timestamps = np.arange(ts_first, ts_last+step/2, step)
141-
# using np.linspace only differs in step if n_samples is different (as n_samples must be an integer number (see implementation below).
142-
# therefore we stick with np.arange (in spite of possible floating point error accumulation, but to make sure that ts_last is included, we add a half-step. This therefore comes at the cost of a overshoot, but i consider this acceptable considering this stamp would only be from one stream, and not part of all other and therefore is kind of arbitray anyways.
158+
# arange implementation:
159+
# step = 1 / sampling_rate
160+
# aligned_timestamps = np.arange(ts_first, ts_last + step / 2, step)
143161
# linspace implementation:
144-
# n_samples = int(np.round((full_dur * sampling_rate),0))+1
145-
# aligned_timestamps = np.linspace(ts_first, ts_last, n_samples)
146-
162+
# add 1 to the number of samples to include the last sample
163+
n_samples = int(np.round((full_dur * sampling_rate), 0)) + 1
164+
aligned_timestamps = np.linspace(ts_first, ts_last, n_samples)
165+
147166
channels = 0
148167
for stream in streams:
149168
# print(stream)
150169
channels += int(stream["info"]["channel_count"][0])
151170
# https://stackoverflow.com/questions/1704823/create-numpy-matrix-filled-with-nans The timings show a preference for ndarray.fill(..) as the faster alternative.
152-
aligned_timeseries = np.empty((len(aligned_timestamps),
153-
channels,), dtype=object)
171+
aligned_timeseries = np.empty(
172+
(
173+
len(aligned_timestamps),
174+
channels,
175+
),
176+
dtype=object,
177+
)
154178
aligned_timeseries.fill(np.nan)
155179

156-
chan_start = 0
180+
chan_start = 0
157181
chan_end = 0
158182
for stream in streams:
159183
sid = stream["info"]["stream_id"]
160-
align = align_foo.get(sid, _shift_align)
184+
align = align_foo.get(sid, _shift_align)
161185
chan_cnt = int(stream["info"]["channel_count"][0])
162186
new_timeseries = np.empty((len(aligned_timestamps), chan_cnt), dtype=object)
163187
new_timeseries.fill(np.nan)
164-
for seg_start, seg_stop in stream["info"]["segments"]:
165-
_new_timeseries = align(
166-
stream["time_stamps"][seg_start:seg_stop+1],
167-
stream["time_series"][seg_start:seg_stop+1],
168-
aligned_timestamps)
188+
print("Stream #", sid, " has ", len(stream["info"]["segments"]), "segments")
189+
for seg_idx, (seg_start, seg_stop) in enumerate(stream["info"]["segments"]):
190+
print(seg_idx, ": from index ", seg_start, "to ", seg_stop + 1)
191+
# segments have been created including the stop index, so we need to add 1 to include the last sample
192+
segment_old_timestamps = stream["time_stamps"][seg_start : seg_stop + 1]
193+
segment_old_timeseries = stream["time_series"][seg_start : seg_stop + 1]
194+
# Sanity check for duplicate timestamps
195+
if len(np.unique(segment_old_timestamps)) != len(segment_old_timestamps):
196+
raise RuntimeError("Duplicate timestamps found in old_timestamps")
197+
# apply align function as defined by the user (or default)
198+
segment_new_timeseries = align(
199+
segment_old_timestamps,
200+
segment_old_timeseries,
201+
aligned_timestamps,
202+
)
169203
# pick indices of the NEW timestamps closest to when segments start and stop
170204
a = stream["time_stamps"][seg_start]
171205
b = stream["time_stamps"][seg_stop]
172-
aix = np.argmin(np.abs(aligned_timestamps-a))
173-
bix = np.argmin(np.abs(aligned_timestamps-b))
206+
aix = np.argmin(np.abs(aligned_timestamps - a))
207+
bix = np.argmin(np.abs(aligned_timestamps - b))
174208
# and store only this aligned segment, leaving the rest as nans (or aligned as other segments)
175-
new_timeseries[aix:bix+1] = _new_timeseries[aix:bix+1]
209+
new_timeseries[aix : bix + 1] = segment_new_timeseries[aix : bix + 1]
176210

177211
# store the new timeseries at the respective channel indices in the 2D array
178212
chan_start = chan_end

test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import matplotlib.pyplot as plt
2+
import pyxdf
3+
4+
if __name__ == "__main__":
5+
fname = "/home/rtgugg/Downloads/sub-13_ses-S001_task-HCT_run-001_eeg.xdf"
6+
# streams, header = pyxdf.load_xdf(
7+
# fname, select_streams=[2, 5]
8+
# ) # EEG and ACC streams
9+
10+
# pyxdf.align_streams(streams)
11+
12+
streams, header = pyxdf.load_xdf(fname, select_streams=[2]) # EEG stream
13+
plt.plot(streams[0]["time_stamps"])
14+
plt.show()
15+
pyxdf.align_streams(streams)

0 commit comments

Comments
 (0)