Skip to content

Commit 8d2d707

Browse files
committed
Did a performance analysis of the search
1 parent 59f0224 commit 8d2d707

25 files changed

Lines changed: 8981 additions & 21 deletions

benchmarks/data_generator.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""Generate synthetic and real HED strings/Series for benchmarking.
2+
3+
Usage::
4+
5+
from data_generator import DataGenerator
6+
gen = DataGenerator() # loads schema 8.3.0
7+
s = gen.make_string(n_tags=10, n_groups=2, depth=1)
8+
series = gen.make_series(n_rows=1000, n_tags=10, n_groups=2, depth=1)
9+
real = gen.load_real_data(tile_to=5000)
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import os
15+
import random
16+
17+
import pandas as pd
18+
19+
from hed.schema import load_schema_version
20+
from hed.models.schema_lookup import generate_schema_lookup
21+
from hed.models.tabular_input import TabularInput
22+
from hed.models.df_util import convert_to_form
23+
24+
25+
class DataGenerator:
26+
"""Build synthetic and real HED data for benchmarking."""
27+
28+
def __init__(self, schema_version="8.3.0", seed=42):
29+
self.schema = load_schema_version(schema_version)
30+
self.lookup = generate_schema_lookup(self.schema)
31+
self._rng = random.Random(seed)
32+
33+
# Collect real tag short names from the schema for realistic generation
34+
self._all_tags = []
35+
for name, entry in self.schema.tags.items():
36+
if name.endswith("/#"):
37+
continue
38+
short = getattr(entry, "short_tag_name", name.rsplit("/", 1)[-1])
39+
self._all_tags.append(short)
40+
41+
# Separate leaf vs non-leaf for variety
42+
self._tags = list(self._all_tags)
43+
44+
# ------------------------------------------------------------------
45+
# Single string generation
46+
# ------------------------------------------------------------------
47+
48+
def _pick_tags(self, n, repeats=0):
49+
"""Pick *n* unique tags, then append *repeats* duplicates of the first."""
50+
chosen = self._rng.sample(self._tags, min(n, len(self._tags)))
51+
if repeats and chosen:
52+
chosen.extend([chosen[0]] * repeats)
53+
return chosen
54+
55+
def make_string(self, n_tags=5, n_groups=0, depth=0, repeats=0, form="short"):
56+
"""Build a single synthetic HED string.
57+
58+
Parameters:
59+
n_tags: Total number of tag tokens (spread across top-level and groups).
60+
n_groups: Number of parenthesised groups to create.
61+
depth: Maximum nesting depth inside groups.
62+
repeats: Number of duplicate copies of the first tag to append.
63+
form: 'short' | 'long' — tag form.
64+
65+
Returns:
66+
str: A raw HED string.
67+
"""
68+
tags = self._pick_tags(n_tags, repeats=repeats)
69+
if form == "long":
70+
tags = self._to_long(tags)
71+
72+
if n_groups == 0 or depth == 0:
73+
return ", ".join(tags)
74+
75+
# Distribute tags across top-level and groups
76+
top_count = max(1, n_tags - n_groups * 2)
77+
top_tags = tags[:top_count]
78+
remaining = tags[top_count:]
79+
80+
parts = list(top_tags)
81+
for i in range(n_groups):
82+
group_tags = remaining[i * 2 : i * 2 + 2] if i * 2 + 2 <= len(remaining) else remaining[i * 2 :]
83+
if not group_tags:
84+
group_tags = [self._rng.choice(self._tags)]
85+
parts.append(self._wrap_group(group_tags, depth))
86+
87+
return ", ".join(parts)
88+
89+
def _wrap_group(self, tags, depth):
90+
"""Recursively nest *tags* to the given *depth*."""
91+
inner = ", ".join(tags)
92+
result = f"({inner})"
93+
for _ in range(depth - 1):
94+
extra = self._rng.choice(self._tags)
95+
result = f"({extra}, {result})"
96+
return result
97+
98+
def make_deeply_nested_string(self, depth, tags_per_level=2):
99+
"""Build a string with deep nesting: (A, (B, (C, ...))).
100+
101+
Parameters:
102+
depth: Number of nesting levels.
103+
tags_per_level: Tags at each level.
104+
105+
Returns:
106+
str: Deeply nested HED string.
107+
"""
108+
tags = self._pick_tags(depth * tags_per_level + 2)
109+
# Build inside-out
110+
inner = ", ".join(tags[:tags_per_level])
111+
for i in range(depth):
112+
level_tags = tags[tags_per_level + i * tags_per_level : tags_per_level + (i + 1) * tags_per_level]
113+
if not level_tags:
114+
level_tags = [self._rng.choice(self._tags)]
115+
inner = f"({', '.join(level_tags)}, ({inner}))"
116+
return f"Event, Action, {inner}"
117+
118+
def make_string_with_specific_tags(self, target_tags, n_extra=5, n_groups=2, depth=1, repeats=0):
119+
"""Build a string guaranteed to contain specific tags.
120+
121+
Parameters:
122+
target_tags: List of tag names to include.
123+
n_extra: Number of random extra tags.
124+
n_groups: Number of groups.
125+
depth: Nesting depth.
126+
repeats: How many times to repeat the first target tag.
127+
128+
Returns:
129+
str: HED string containing the target tags.
130+
"""
131+
extra = self._pick_tags(n_extra)
132+
all_tags = list(target_tags) + extra + [target_tags[0]] * repeats
133+
self._rng.shuffle(all_tags)
134+
135+
if n_groups == 0 or depth == 0:
136+
return ", ".join(all_tags)
137+
138+
top_count = max(1, len(all_tags) - n_groups * 2)
139+
top_tags = all_tags[:top_count]
140+
remaining = all_tags[top_count:]
141+
142+
parts = list(top_tags)
143+
for i in range(n_groups):
144+
group_tags = remaining[i * 2 : i * 2 + 2] if i * 2 + 2 <= len(remaining) else remaining[i * 2 :]
145+
if not group_tags:
146+
group_tags = [self._rng.choice(self._tags)]
147+
parts.append(self._wrap_group(group_tags, depth))
148+
149+
return ", ".join(parts)
150+
151+
def _to_long(self, short_tags):
152+
"""Convert short tag names to long form via the schema."""
153+
from hed.models.hed_tag import HedTag
154+
155+
out = []
156+
for t in short_tags:
157+
try:
158+
out.append(HedTag(t, self.schema).long_tag)
159+
except Exception:
160+
out.append(t)
161+
return out
162+
163+
# ------------------------------------------------------------------
164+
# Series generation
165+
# ------------------------------------------------------------------
166+
167+
def make_series(self, n_rows, *, n_tags=5, n_groups=0, depth=0, repeats=0, form="short", heterogeneous=False):
168+
"""Build a pd.Series of HED strings.
169+
170+
Parameters:
171+
n_rows: Number of rows.
172+
n_tags, n_groups, depth, repeats, form: Passed to make_string.
173+
heterogeneous: If True, randomise parameters per row.
174+
"""
175+
if heterogeneous:
176+
rows = []
177+
for _ in range(n_rows):
178+
nt = self._rng.choice([3, 5, 10, 15, 25])
179+
ng = self._rng.choice([0, 1, 2, 5])
180+
d = self._rng.choice([0, 1, 2])
181+
rows.append(self.make_string(n_tags=nt, n_groups=ng, depth=d, form=form))
182+
return pd.Series(rows)
183+
else:
184+
# Homogeneous: one template, tiled
185+
template = self.make_string(n_tags=n_tags, n_groups=n_groups, depth=depth, repeats=repeats, form=form)
186+
return pd.Series([template] * n_rows)
187+
188+
# ------------------------------------------------------------------
189+
# Real data
190+
# ------------------------------------------------------------------
191+
192+
def load_real_data(self, tile_to=None, form="short"):
193+
"""Load the FacePerception BIDS events and return a HED Series.
194+
195+
Parameters:
196+
tile_to: If set, tile the series up to this many rows.
197+
form: 'short' | 'long'.
198+
199+
Returns:
200+
pd.Series of HED strings.
201+
"""
202+
bids_root = os.path.realpath(
203+
os.path.join(os.path.dirname(__file__), "..", "tests", "data", "bids_tests", "eeg_ds003645s_hed")
204+
)
205+
sidecar = os.path.join(bids_root, "task-FacePerception_events.json")
206+
events = os.path.join(bids_root, "sub-002", "eeg", "sub-002_task-FacePerception_run-1_events.tsv")
207+
tab = TabularInput(events, sidecar)
208+
series = tab.series_filtered
209+
210+
if form == "long":
211+
df = series.copy()
212+
convert_to_form(df, self.schema, "long_tag")
213+
series = df
214+
215+
if tile_to and tile_to > len(series):
216+
reps = (tile_to // len(series)) + 1
217+
series = pd.Series(list(series) * reps).iloc[:tile_to].reset_index(drop=True)
218+
219+
return series
220+
221+
222+
# Quick self-test
223+
if __name__ == "__main__":
224+
gen = DataGenerator()
225+
print(f"Schema tags available: {len(gen._tags)}")
226+
print(f"Sample string (5 tags): {gen.make_string(5)}")
227+
print(f"Sample string (10 tags, 2 groups, depth 2): {gen.make_string(10, 2, 2)}")
228+
print(f"Sample string (5 tags, 3 repeats): {gen.make_string(5, repeats=3)}")
229+
print(f"Real data rows: {len(gen.load_real_data())}")
230+
print(f"Tiled to 500: {len(gen.load_real_data(tile_to=500))}")
35.8 KB
Loading
112 KB
Loading
74.3 KB
Loading
141 KB
Loading
57.7 KB
Loading
74.6 KB
Loading
73.3 KB
Loading
65.8 KB
Loading
87.5 KB
Loading

0 commit comments

Comments
 (0)