|
| 1 | +"""Generate synthetic and real HED strings/Series for benchmarking. |
| 2 | +
|
| 3 | +Usage:: |
| 4 | +
|
| 5 | + from data_generator import DataGenerator |
| 6 | + gen = DataGenerator() # loads schema 8.3.0 |
| 7 | + s = gen.make_string(n_tags=10, n_groups=2, depth=1) |
| 8 | + series = gen.make_series(n_rows=1000, n_tags=10, n_groups=2, depth=1) |
| 9 | + real = gen.load_real_data(tile_to=5000) |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import os |
| 15 | +import random |
| 16 | + |
| 17 | +import pandas as pd |
| 18 | + |
| 19 | +from hed.schema import load_schema_version |
| 20 | +from hed.models.schema_lookup import generate_schema_lookup |
| 21 | +from hed.models.tabular_input import TabularInput |
| 22 | +from hed.models.df_util import convert_to_form |
| 23 | + |
| 24 | + |
| 25 | +class DataGenerator: |
| 26 | + """Build synthetic and real HED data for benchmarking.""" |
| 27 | + |
| 28 | + def __init__(self, schema_version="8.3.0", seed=42): |
| 29 | + self.schema = load_schema_version(schema_version) |
| 30 | + self.lookup = generate_schema_lookup(self.schema) |
| 31 | + self._rng = random.Random(seed) |
| 32 | + |
| 33 | + # Collect real tag short names from the schema for realistic generation |
| 34 | + self._all_tags = [] |
| 35 | + for name, entry in self.schema.tags.items(): |
| 36 | + if name.endswith("/#"): |
| 37 | + continue |
| 38 | + short = getattr(entry, "short_tag_name", name.rsplit("/", 1)[-1]) |
| 39 | + self._all_tags.append(short) |
| 40 | + |
| 41 | + # Separate leaf vs non-leaf for variety |
| 42 | + self._tags = list(self._all_tags) |
| 43 | + |
| 44 | + # ------------------------------------------------------------------ |
| 45 | + # Single string generation |
| 46 | + # ------------------------------------------------------------------ |
| 47 | + |
| 48 | + def _pick_tags(self, n, repeats=0): |
| 49 | + """Pick *n* unique tags, then append *repeats* duplicates of the first.""" |
| 50 | + chosen = self._rng.sample(self._tags, min(n, len(self._tags))) |
| 51 | + if repeats and chosen: |
| 52 | + chosen.extend([chosen[0]] * repeats) |
| 53 | + return chosen |
| 54 | + |
| 55 | + def make_string(self, n_tags=5, n_groups=0, depth=0, repeats=0, form="short"): |
| 56 | + """Build a single synthetic HED string. |
| 57 | +
|
| 58 | + Parameters: |
| 59 | + n_tags: Total number of tag tokens (spread across top-level and groups). |
| 60 | + n_groups: Number of parenthesised groups to create. |
| 61 | + depth: Maximum nesting depth inside groups. |
| 62 | + repeats: Number of duplicate copies of the first tag to append. |
| 63 | + form: 'short' | 'long' — tag form. |
| 64 | +
|
| 65 | + Returns: |
| 66 | + str: A raw HED string. |
| 67 | + """ |
| 68 | + tags = self._pick_tags(n_tags, repeats=repeats) |
| 69 | + if form == "long": |
| 70 | + tags = self._to_long(tags) |
| 71 | + |
| 72 | + if n_groups == 0 or depth == 0: |
| 73 | + return ", ".join(tags) |
| 74 | + |
| 75 | + # Distribute tags across top-level and groups |
| 76 | + top_count = max(1, n_tags - n_groups * 2) |
| 77 | + top_tags = tags[:top_count] |
| 78 | + remaining = tags[top_count:] |
| 79 | + |
| 80 | + parts = list(top_tags) |
| 81 | + for i in range(n_groups): |
| 82 | + group_tags = remaining[i * 2 : i * 2 + 2] if i * 2 + 2 <= len(remaining) else remaining[i * 2 :] |
| 83 | + if not group_tags: |
| 84 | + group_tags = [self._rng.choice(self._tags)] |
| 85 | + parts.append(self._wrap_group(group_tags, depth)) |
| 86 | + |
| 87 | + return ", ".join(parts) |
| 88 | + |
| 89 | + def _wrap_group(self, tags, depth): |
| 90 | + """Recursively nest *tags* to the given *depth*.""" |
| 91 | + inner = ", ".join(tags) |
| 92 | + result = f"({inner})" |
| 93 | + for _ in range(depth - 1): |
| 94 | + extra = self._rng.choice(self._tags) |
| 95 | + result = f"({extra}, {result})" |
| 96 | + return result |
| 97 | + |
| 98 | + def make_deeply_nested_string(self, depth, tags_per_level=2): |
| 99 | + """Build a string with deep nesting: (A, (B, (C, ...))). |
| 100 | +
|
| 101 | + Parameters: |
| 102 | + depth: Number of nesting levels. |
| 103 | + tags_per_level: Tags at each level. |
| 104 | +
|
| 105 | + Returns: |
| 106 | + str: Deeply nested HED string. |
| 107 | + """ |
| 108 | + tags = self._pick_tags(depth * tags_per_level + 2) |
| 109 | + # Build inside-out |
| 110 | + inner = ", ".join(tags[:tags_per_level]) |
| 111 | + for i in range(depth): |
| 112 | + level_tags = tags[tags_per_level + i * tags_per_level : tags_per_level + (i + 1) * tags_per_level] |
| 113 | + if not level_tags: |
| 114 | + level_tags = [self._rng.choice(self._tags)] |
| 115 | + inner = f"({', '.join(level_tags)}, ({inner}))" |
| 116 | + return f"Event, Action, {inner}" |
| 117 | + |
| 118 | + def make_string_with_specific_tags(self, target_tags, n_extra=5, n_groups=2, depth=1, repeats=0): |
| 119 | + """Build a string guaranteed to contain specific tags. |
| 120 | +
|
| 121 | + Parameters: |
| 122 | + target_tags: List of tag names to include. |
| 123 | + n_extra: Number of random extra tags. |
| 124 | + n_groups: Number of groups. |
| 125 | + depth: Nesting depth. |
| 126 | + repeats: How many times to repeat the first target tag. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + str: HED string containing the target tags. |
| 130 | + """ |
| 131 | + extra = self._pick_tags(n_extra) |
| 132 | + all_tags = list(target_tags) + extra + [target_tags[0]] * repeats |
| 133 | + self._rng.shuffle(all_tags) |
| 134 | + |
| 135 | + if n_groups == 0 or depth == 0: |
| 136 | + return ", ".join(all_tags) |
| 137 | + |
| 138 | + top_count = max(1, len(all_tags) - n_groups * 2) |
| 139 | + top_tags = all_tags[:top_count] |
| 140 | + remaining = all_tags[top_count:] |
| 141 | + |
| 142 | + parts = list(top_tags) |
| 143 | + for i in range(n_groups): |
| 144 | + group_tags = remaining[i * 2 : i * 2 + 2] if i * 2 + 2 <= len(remaining) else remaining[i * 2 :] |
| 145 | + if not group_tags: |
| 146 | + group_tags = [self._rng.choice(self._tags)] |
| 147 | + parts.append(self._wrap_group(group_tags, depth)) |
| 148 | + |
| 149 | + return ", ".join(parts) |
| 150 | + |
| 151 | + def _to_long(self, short_tags): |
| 152 | + """Convert short tag names to long form via the schema.""" |
| 153 | + from hed.models.hed_tag import HedTag |
| 154 | + |
| 155 | + out = [] |
| 156 | + for t in short_tags: |
| 157 | + try: |
| 158 | + out.append(HedTag(t, self.schema).long_tag) |
| 159 | + except Exception: |
| 160 | + out.append(t) |
| 161 | + return out |
| 162 | + |
| 163 | + # ------------------------------------------------------------------ |
| 164 | + # Series generation |
| 165 | + # ------------------------------------------------------------------ |
| 166 | + |
| 167 | + def make_series(self, n_rows, *, n_tags=5, n_groups=0, depth=0, repeats=0, form="short", heterogeneous=False): |
| 168 | + """Build a pd.Series of HED strings. |
| 169 | +
|
| 170 | + Parameters: |
| 171 | + n_rows: Number of rows. |
| 172 | + n_tags, n_groups, depth, repeats, form: Passed to make_string. |
| 173 | + heterogeneous: If True, randomise parameters per row. |
| 174 | + """ |
| 175 | + if heterogeneous: |
| 176 | + rows = [] |
| 177 | + for _ in range(n_rows): |
| 178 | + nt = self._rng.choice([3, 5, 10, 15, 25]) |
| 179 | + ng = self._rng.choice([0, 1, 2, 5]) |
| 180 | + d = self._rng.choice([0, 1, 2]) |
| 181 | + rows.append(self.make_string(n_tags=nt, n_groups=ng, depth=d, form=form)) |
| 182 | + return pd.Series(rows) |
| 183 | + else: |
| 184 | + # Homogeneous: one template, tiled |
| 185 | + template = self.make_string(n_tags=n_tags, n_groups=n_groups, depth=depth, repeats=repeats, form=form) |
| 186 | + return pd.Series([template] * n_rows) |
| 187 | + |
| 188 | + # ------------------------------------------------------------------ |
| 189 | + # Real data |
| 190 | + # ------------------------------------------------------------------ |
| 191 | + |
| 192 | + def load_real_data(self, tile_to=None, form="short"): |
| 193 | + """Load the FacePerception BIDS events and return a HED Series. |
| 194 | +
|
| 195 | + Parameters: |
| 196 | + tile_to: If set, tile the series up to this many rows. |
| 197 | + form: 'short' | 'long'. |
| 198 | +
|
| 199 | + Returns: |
| 200 | + pd.Series of HED strings. |
| 201 | + """ |
| 202 | + bids_root = os.path.realpath( |
| 203 | + os.path.join(os.path.dirname(__file__), "..", "tests", "data", "bids_tests", "eeg_ds003645s_hed") |
| 204 | + ) |
| 205 | + sidecar = os.path.join(bids_root, "task-FacePerception_events.json") |
| 206 | + events = os.path.join(bids_root, "sub-002", "eeg", "sub-002_task-FacePerception_run-1_events.tsv") |
| 207 | + tab = TabularInput(events, sidecar) |
| 208 | + series = tab.series_filtered |
| 209 | + |
| 210 | + if form == "long": |
| 211 | + df = series.copy() |
| 212 | + convert_to_form(df, self.schema, "long_tag") |
| 213 | + series = df |
| 214 | + |
| 215 | + if tile_to and tile_to > len(series): |
| 216 | + reps = (tile_to // len(series)) + 1 |
| 217 | + series = pd.Series(list(series) * reps).iloc[:tile_to].reset_index(drop=True) |
| 218 | + |
| 219 | + return series |
| 220 | + |
| 221 | + |
| 222 | +# Quick self-test |
| 223 | +if __name__ == "__main__": |
| 224 | + gen = DataGenerator() |
| 225 | + print(f"Schema tags available: {len(gen._tags)}") |
| 226 | + print(f"Sample string (5 tags): {gen.make_string(5)}") |
| 227 | + print(f"Sample string (10 tags, 2 groups, depth 2): {gen.make_string(10, 2, 2)}") |
| 228 | + print(f"Sample string (5 tags, 3 repeats): {gen.make_string(5, repeats=3)}") |
| 229 | + print(f"Real data rows: {len(gen.load_real_data())}") |
| 230 | + print(f"Tiled to 500: {len(gen.load_real_data(tile_to=500))}") |
0 commit comments