Skip to content

Commit 9f6538f

Browse files
committed
benchmark LRR on Apple Silicon
1 parent 8afc241 commit 9f6538f

1 file changed

Lines changed: 317 additions & 0 deletions

File tree

tests/benchmark/bench_lrr.py

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
"""Performance benchmarks for LRRTransformer.
2+
3+
Run with:
4+
.venv/bin/python tests/benchmark/bench_lrr.py
5+
6+
Benchmarks:
7+
1. _process (inference) — numpy, varying chunk sizes
8+
2. partial_fit (training) — numpy, varying chunk sizes
9+
3. _process — torch MPS (Apple Silicon GPU)
10+
4. partial_fit — torch MPS (Apple Silicon GPU)
11+
"""
12+
13+
import time
14+
15+
import numpy as np
16+
from ezmsg.util.messages.axisarray import AxisArray
17+
18+
from ezmsg.learn.process.ssr import LRRSettings, LRRTransformer
19+
20+
# ---------------------------------------------------------------------------
# Parameters
# ---------------------------------------------------------------------------

# Benchmark dimensions: 512 channels split into 8 equal clusters.
N_CH = 512
N_CLUSTERS = 8
CLUSTER_SIZE = N_CH // N_CLUSTERS  # 64
# Sampling rate used for the AxisArray time axis (30 kHz).
FS = 30_000.0
# Per-call chunk lengths (in samples) swept by each benchmark.
CHUNK_SIZES = [20, 50, 100, 150, 200, 300]
# Untimed warmup calls before the measured iterations.
WARMUP_ITERS = 20
# Number of timed iterations per chunk size; the median is reported.
BENCH_ITERS = 200

# Contiguous, non-overlapping channel-index lists, one per cluster.
CLUSTERS = [list(range(i * CLUSTER_SIZE, (i + 1) * CLUSTER_SIZE)) for i in range(N_CLUSTERS)]
33+
34+
35+
# ---------------------------------------------------------------------------
36+
# Helpers
37+
# ---------------------------------------------------------------------------
38+
39+
40+
def _make_msg(data, key: str = "bench") -> AxisArray:
    """Wrap *data* in a time-by-channel AxisArray at the benchmark sample rate."""
    time_axis = AxisArray.TimeAxis(fs=FS, offset=0.0)
    return AxisArray(data=data, dims=["time", "ch"], axes={"time": time_axis}, key=key)
47+
48+
49+
def _print_header(title: str) -> None:
50+
print(f"\n{'=' * 70}")
51+
print(f" {title}")
52+
print(f"{'=' * 70}")
53+
54+
55+
def _print_row(chunk: int, median_us: float, throughput_khz: float) -> None:
56+
print(f" chunk={chunk:>4d} | {median_us:8.1f} us/call | {throughput_khz:8.1f} kHz effective")
57+
58+
59+
def _bench_loop(fn, n_warmup: int, n_iters: int) -> list[float]:
60+
"""Run fn() for warmup + measured iterations, return list of elapsed times."""
61+
for _ in range(n_warmup):
62+
fn()
63+
times = []
64+
for _ in range(n_iters):
65+
t0 = time.perf_counter()
66+
fn()
67+
times.append(time.perf_counter() - t0)
68+
return times
69+
70+
71+
def _bench_loop_sync(fn, sync_fn, n_warmup: int, n_iters: int) -> list[float]:
72+
"""Like _bench_loop but calls sync_fn() before each timing measurement."""
73+
for _ in range(n_warmup):
74+
fn()
75+
sync_fn()
76+
times = []
77+
for _ in range(n_iters):
78+
t0 = time.perf_counter()
79+
fn()
80+
sync_fn()
81+
times.append(time.perf_counter() - t0)
82+
return times
83+
84+
85+
# ---------------------------------------------------------------------------
86+
# NumPy benchmarks
87+
# ---------------------------------------------------------------------------
88+
89+
90+
def bench_process_numpy() -> None:
    """Benchmark LRRTransformer inference (``send``) on NumPy arrays."""
    _print_header("_process (inference) — NumPy")
    print(f" {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
    print()

    rng = np.random.default_rng(0)

    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
    # Fit via partial_fit so the message hash is primed for send()
    proc.partial_fit(_make_msg(rng.standard_normal((2000, N_CH))))

    for chunk in CHUNK_SIZES:
        msg = _make_msg(rng.standard_normal((chunk, N_CH)))
        # Prime — first send triggers the affine's _reset_state
        proc.send(msg)

        samples = _bench_loop(lambda m=msg: proc.send(m), WARMUP_ITERS, BENCH_ITERS)
        median_s = float(np.median(samples))
        # Report us/call and effective sample throughput in kHz.
        _print_row(chunk, median_s * 1e6, chunk / median_s / 1e3)
112+
113+
114+
def bench_partial_fit_numpy() -> None:
    """Benchmark LRRTransformer training (``partial_fit``) on NumPy arrays."""
    _print_header("partial_fit (training) — NumPy")
    print(f" {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
    print()

    rng = np.random.default_rng(1)
    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))

    for chunk in CHUNK_SIZES:
        msg = _make_msg(rng.standard_normal((chunk, N_CH)))
        # Prime
        proc.partial_fit(msg)

        samples = _bench_loop(lambda m=msg: proc.partial_fit(m), WARMUP_ITERS, BENCH_ITERS)
        median_s = float(np.median(samples))
        # Report us/call and effective sample throughput in kHz.
        _print_row(chunk, median_s * 1e6, chunk / median_s / 1e3)
132+
133+
134+
# ---------------------------------------------------------------------------
135+
# Torch MPS benchmarks
136+
# ---------------------------------------------------------------------------
137+
138+
139+
def bench_process_mps() -> None:
    """Benchmark LRRTransformer inference (``send``) with torch tensors on MPS."""
    import torch

    if not torch.backends.mps.is_available():
        print("\n [SKIPPED] MPS not available")
        return

    _print_header("_process (inference) — Torch MPS")
    print(f" {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
    print()

    rng = np.random.default_rng(0)
    device = torch.device("mps")

    # Fit on CPU (numpy), then send MPS data to trigger device conversion
    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
    proc.partial_fit(_make_msg(rng.standard_normal((2000, N_CH))))

    for chunk in CHUNK_SIZES:
        msg = _make_msg(torch.randn(chunk, N_CH, device=device, dtype=torch.float32))
        # Prime — first send triggers affine's _reset_state with device conversion
        proc.send(msg)

        # torch.mps.synchronize inside the timed region drains the GPU queue.
        samples = _bench_loop_sync(
            lambda m=msg: proc.send(m), torch.mps.synchronize, WARMUP_ITERS, BENCH_ITERS
        )
        median_s = float(np.median(samples))
        _print_row(chunk, median_s * 1e6, chunk / median_s / 1e3)
171+
172+
173+
def bench_partial_fit_mps() -> None:
    """Benchmark LRRTransformer training (``partial_fit``) with torch tensors on MPS.

    Data is generated directly on the MPS device; ``torch.mps.synchronize`` is
    called inside the timed region so queued GPU work is charged to each call.
    """
    import torch

    if not torch.backends.mps.is_available():
        print("\n [SKIPPED] MPS not available")
        return

    _print_header("partial_fit (training) — Torch MPS")
    print(f" {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
    print()

    device = torch.device("mps")
    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))

    def sync():
        torch.mps.synchronize()

    for chunk in CHUNK_SIZES:
        data_mps = torch.randn(chunk, N_CH, device=device, dtype=torch.float32)
        msg = _make_msg(data_mps)
        # Prime
        proc.partial_fit(msg)

        times = _bench_loop_sync(lambda: proc.partial_fit(msg), sync, WARMUP_ITERS, BENCH_ITERS)
        median_us = np.median(times) * 1e6
        throughput = chunk / np.median(times)  # samples/s
        _print_row(chunk, median_us, throughput / 1e3)
201+
202+
203+
# ---------------------------------------------------------------------------
204+
# MLX benchmarks
205+
# ---------------------------------------------------------------------------
206+
207+
208+
def bench_process_mlx() -> None:
    """Benchmark LRRTransformer inference (``send``) with MLX arrays.

    MLX is lazy, so ``mx.eval(out.data)`` is called on each output inside the
    timed region to force the actual compute into the measurement.
    """
    try:
        import mlx.core as mx
    except ImportError:
        print("\n [SKIPPED] MLX not installed")
        return

    _print_header("_process (inference) — MLX")
    print(f" {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
    print()

    rng = np.random.default_rng(0)

    # Fit on CPU (numpy), then send MLX data
    fit_data = rng.standard_normal((2000, N_CH))
    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
    proc.partial_fit(_make_msg(fit_data))

    for chunk in CHUNK_SIZES:
        data_mlx = mx.random.normal(shape=(chunk, N_CH))
        msg = _make_msg(data_mlx)
        # Prime — first send triggers affine's _reset_state with MLX conversion
        out = proc.send(msg)
        mx.eval(out.data)

        def run():
            out = proc.send(msg)
            mx.eval(out.data)

        times = _bench_loop(run, WARMUP_ITERS, BENCH_ITERS)
        median_us = np.median(times) * 1e6
        throughput = chunk / np.median(times)  # samples/s
        _print_row(chunk, median_us, throughput / 1e3)
244+
245+
246+
def bench_partial_fit_mlx() -> None:
    """Benchmark LRRTransformer training (``partial_fit``) with MLX arrays.

    MLX's ``linalg.inv`` does not support the GPU stream, so the transformer's
    ``_solve_weights`` is wrapped to temporarily route ``mx.linalg.inv``
    through the CPU stream (restored via try/finally even if solving raises).
    """
    try:
        import mlx.core as mx
    except ImportError:
        print("\n [SKIPPED] MLX not installed")
        return

    _print_header("partial_fit (training) — MLX")
    print(f" {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
    # MLX linalg.inv doesn't support GPU yet; run inv on CPU stream
    print(" NOTE: linalg.inv runs on mx.cpu stream (GPU not supported)")
    print()

    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))

    # Monkey-patch _solve_weights to use mx.cpu stream for inv
    original_solve = proc._solve_weights

    def _solve_weights_cpu_inv(cxx):
        from array_api_compat import get_namespace

        xp = get_namespace(cxx)
        # Only redirect inv when the accumulator actually lives in MLX.
        if xp.__name__ == "mlx.core":
            orig_inv = mx.linalg.inv
            mx.linalg.inv = lambda a: orig_inv(a, stream=mx.cpu)
            try:
                return original_solve(cxx)
            finally:
                mx.linalg.inv = orig_inv
        return original_solve(cxx)

    proc._solve_weights = _solve_weights_cpu_inv

    for chunk in CHUNK_SIZES:
        data_mlx = mx.random.normal(shape=(chunk, N_CH))
        msg = _make_msg(data_mlx)
        # Prime
        proc.partial_fit(msg)

        def run():
            proc.partial_fit(msg)
            # NOTE(review): mx.eval() with no arguments evaluates nothing, so
            # lazy work queued by partial_fit may not be charged to this
            # timing — consider passing the transformer's state arrays here.
            mx.eval()

        times = _bench_loop(run, WARMUP_ITERS, BENCH_ITERS)
        median_us = np.median(times) * 1e6
        throughput = chunk / np.median(times)  # samples/s
        _print_row(chunk, median_us, throughput / 1e3)
297+
298+
299+
# ---------------------------------------------------------------------------
300+
# Main
301+
# ---------------------------------------------------------------------------
302+
303+
if __name__ == "__main__":
    print(f"LRRTransformer benchmark: {N_CH} channels, {N_CLUSTERS} clusters of {CLUSTER_SIZE}, fs={FS / 1e3:.0f} kHz")

    # Run every backend's inference and training benchmark in order.
    for bench in (
        bench_process_numpy,
        bench_partial_fit_numpy,
        bench_process_mps,
        bench_partial_fit_mps,
        bench_process_mlx,
        bench_partial_fit_mlx,
    ):
        bench()

    print()
    print("Real-time budgets at 30 kHz:")
    for chunk in CHUNK_SIZES:
        # Wall-clock budget per call if the stream is to keep up with FS.
        budget = chunk / FS * 1e6
        print(f" chunk={chunk:>4d} -> {budget:8.1f} us")

0 commit comments

Comments
 (0)