Skip to content

speed up perform_gsn: device-aware fast backend + numerical fixes#27

Open
jacob-prince wants to merge 24 commits into
mainfrom
speedup-perform-gsn
Open

speed up perform_gsn: device-aware fast backend + numerical fixes#27
jacob-prince wants to merge 24 commits into
mainfrom
speedup-perform-gsn

Conversation

@jacob-prince

Copy link
Copy Markdown
Collaborator

summary

A faster, optionally GPU-accelerated GSN with NaN / per-unit-missing data
support, plus several numerical-correctness fixes to the shrinkage path.
The numpy CPU path stays the default; torch (cuda/mps) is optional + opt-in.

Paired with the PSN refactor-dec25 PR, which now routes perform_gsn
through this package and so depends on this branch.

what's new

fast backendwhy: GSN's covariance + cross-validated shrinkage estimation
is the runtime bottleneck at large nunits (thousands–tens of thousands of units),
and the original path couldn't handle missing data efficiently.

  • fast_perform_gsn.py — device-native (numpy + optional torch) GSN. Pushes the
    GEMM/solve-heavy work to torch (CPU or GPU) for large speedups; nan-aware
    uneven-trials path; opt['returns'] selector + opt-in eigvecs/eigvals returns
    (opt['eigh_device']) so callers (PSN) can skip re-doing the eigendecomposition;
    large-N memory cleanup.
  • batched_nll.py — batched-Cholesky shrinkage-NLL evaluation (optional torch).
    why: scoring all candidate shrinkage levels at once is far faster, but the full
    stack blows device memory at large N — so it's chunked.
  • missing_units.py — per-unit missing-data GSN (numpy + torch/gpu). why: real
    datasets have units missing on some conditions, not just whole-trial NaNs.
  • perform_gsn.py — wire in the fast path / options.

numerical fixes (shrinkage / covariance)why: use decompositions matched to
the matrices (symmetric / PSD) — faster and more numerically stable — and remove a
regularizer that was perturbing results.

  • calc_shrunken_covariance: remove the 1e-6*I ridge (+ an equivalence test that
    exercises the prior bug).
  • construct_nearest_psd_covariance: eigh instead of svd; matlab
    constructnearestpsdcovariance.m: eig instead of svd.
  • calc_mv_gaussian_pdf: solve_triangular instead of pinv.

testswhy: prove the fast/GPU paths stay equivalent to the reference and guard
the new edge cases.

  • new: gpu edge cases, python speedup equivalence, missing-units, speedup-magnitude
    benchmark, uneven-trials, bugfix tests; extended matlab↔python equivalence harness.

packagingwhy: torch is heavy, so keep it optional; don't drag matplotlib into
every import.

  • optional torch dependency (setup.py / requirements.txt); lazy matplotlib import
    in rsa_noise_ceiling.

review notes

  • numpy path is the default and unchanged in spirit; torch is opt-in via device.
  • a GPU eigendecomposition can pick a different basis on degenerate eigenspaces, so
    a few downstream values differ slightly from a CPU run — inherent backend behavior,
    not a bug (GPU test tolerances are relative, with an argmin-agreement check).

testing

  • pytest tests/
  • matlab↔python equivalence: tests/test_gsn_matlab_python_equivalence.sh

Adds test10_ridge_2d_data_cov as a red test for a longstanding bug in
calc_shrunken_covariance.py: lines 179 and 182 add c + 1e-6*I to a
rank-deficient training covariance, masking the genuine singularity at
alpha = 1. MATLAB's cholcov fails there -> nll = NaN -> min() skips
and picks the largest grid point below 1 (0.98 on linspace(0, 1, 51)).
Python's ridge lets alpha = 1 pass; when validation data lives in the
same low-rank subspace as training, a spurious very-negative log-det
dominates and Python picks shrinklevelD = 1.0.

The default gsn.simulate_data generator produces full-rank populations
that mask the bug, so an optional low_rank_spec field is added to
TEST_DEFS to switch in a custom rank-deficient signal+noise generator.
Test 10 (nvox=50 ncond=40 ntrial=3 rank_signal=5 rank_noise=10) fails
on the unfixed code with shrinklevelD Python=1.0 vs MATLAB=0.98 and
cSb/cS/cNb diverging at max abs err ~2e-3. The next commit removes the
two c + np.eye(...) * 1e-6 lines.
Deletes the rank check and c + np.eye(c.shape[0]) * 1e-6 block in the
2D path. There is no counterpart in the MATLAB original
calcshrunkencovariance.m. The shrinkage formula c2 = alpha*c +
(1-alpha)*diag(c) already keeps c2 non-singular for any alpha < 1 via
the diagonal injection; at alpha = 1, c2 equals the raw sample c and
if that is singular, cholcov fails naturally, nll(p) is set to NaN,
and nanargmin skips it. The Python ridge silently regularized the
input and let alpha = 1 falsely pass the singularity check, which
made Python pick shrinklevelD = 1.0 where MATLAB picks the largest
grid point below 1 (0.98 on linspace(0, 1, 51)) for data whose
population covariance is genuinely low-rank.

Also drops two unused imports (math, scipy.stats as stats) at the
top of the file.

Test 10 (test10_ridge_2d_data_cov) added in the previous commit now
passes; full equivalence suite is 10/10 and Python pytests are 61/61.
The Cholesky factor T is upper-triangular by construction, so
np.linalg.pinv(T) ran a full SVD on a triangular matrix — same
answer at O(N^3) where a triangular solve gives the same result at
O(N^2) per RHS. MATLAB's calcmvgaussianpdf.m already uses `pts / T`
(mrdivide), which dispatches to a triangular solve. Switches to
scipy.linalg.solve_triangular(T, pts.T, lower=False, trans='T').T,
which solves T.T @ X = pts.T for X = pts @ inv(T).
The input is symmetric (we symmetrize it on line 47), so eigh is the
right tool. Mathematically equivalent to the SVD-based path used in
the MATLAB original: for symmetric M with eigendecomposition V*D*V',
SVD gives U=V*sign(D), S=|D|, V_svd=V, so (M + V*S*V')/2 reduces to
V*max(D,0)*V' — exactly what eigh + clamp does directly. SVD does
~2x the work of eigh on symmetric input.

The eig fallback that the reference used when SVD failed to converge
is also dropped — eigh on a symmetric matrix is more numerically
stable than eig, so the SVD-failure escape path is no longer needed.
Used only inside the figure block. Moving the three matplotlib imports
inside that block saves ~400ms on every `import gsn` for callers who
never draw (e.g. perform_gsn / mode=1).
calc_shrunken_covariance picked the optimal shrinkage level by running
the held-out Gaussian NLL through a Python for-loop over all 51
shrinkage levels — each iteration its own O(N^3) Cholesky and O(M*N^2)
triangular solve. For N in the hundreds-to-thousands range that loop
is the dominant cost in perform_gsn.

The new gsn.batched_nll.batched_shrunken_nll collapses those 51
sequential factorizations into a single batched torch.linalg.cholesky_ex
plus a single batched solve_triangular over the (S, N, N) stack of
shrunken covariances. cholesky_ex returns per-slot status without
raising, so singular slots cleanly map to nll = NaN (matching MATLAB's
min(nll) skip-NaN behavior). When torch is absent we fall back to a
numpy + scipy loop that is bit-equivalent to the reference, just with
the mean-subtraction lifted out of the loop (it was invariant across
levels anyway).

calc_shrunken_covariance is refactored to build pts_zm once before the
loop and call the batched helper. The (N, N, S) covs array that the
reference materialized just to index `covs[:,:,min0ix]` at the end is
also dropped — for the wantfull=0 path we recompute the chosen
shrunken cov on demand, avoiding the 51x memory blowup at large N.

Adds matplotlib to requirements.txt (already used in the figure path,
just wasn't declared) and registers torch>=2.0 as an optional 'fast'
extra in setup.py — `pip install gsn[fast]` lights up the batched path
with no code changes at call sites.

Measured speedup of batched_shrunken_nll alone (numpy vs torch CPU):
N=200 2.9x, N=500 2.1x, N=1000 11.0x. Larger N widens the gap further.
Equivalence between paths is at floating-point noise (max|Δnll| ~5e-13).
batched_shrunken_nll now accepts device in {'cpu', 'cuda', 'mps', 'auto'}.
'cpu' is the default (unchanged behavior) and the right choice for
N up to ~1000 because GPU host<->device transfer costs more than the
batched cholesky_ex saves at that size. 'cuda' / 'mps' open up the GPU
path for large N; 'auto' picks cuda > mps > cpu based on availability.
_resolve_device errors clearly if the caller asks for a device this
torch install can't reach (better than letting it fail deep in a kernel
call). On mps we force float32 since Apple Metal has no float64.

calc_shrunken_covariance gains a device kwarg and threads it through.
rsa_noise_ceiling reads opt['device'] (defaults to 'cpu') and passes
it to both calc_shrunken_covariance calls. From a user perspective:

    perform_gsn(data, {'device': 'cuda'})    # explicit
    perform_gsn(data, {'device': 'auto'})    # cuda > mps > cpu
    perform_gsn(data)                        # default 'cpu'

Docstrings updated in perform_gsn, rsa_noise_ceiling, and
calc_shrunken_covariance.
tests/test_gsn_python_speedups.py is a pure-Python pytest suite (no
MATLAB required) covering every change made on this branch:

- calc_mv_gaussian_pdf (pinv -> solve_triangular): parametric N/M
  matches a direct log-density formula, wantomitexp flag, singular
  cov returns err=1, single-variable case, no-input-mutation.
- construct_nearest_psd_covariance (svd -> eigh): already-PSD
  passthrough, indefinite -> PSD projection, eigh vs SVD parity on
  symmetric input, asymmetric symmetrization, scalar/1x1/all-negative
  edge cases.
- calc_shrunken_covariance (ridge removal): regression test that
  full-rank data is unaffected, rank-deficient 2D data picks
  shrinklevelD < 1.0 (the bug fix), 3D path still works.
- rsa_noise_ceiling (lazy matplotlib): subprocess test that
  `import gsn` does not load matplotlib.pyplot.
- batched_nll (new module): torch vs numpy parity at three shapes,
  singular slots -> NaN, all-singular -> all-NaN, N=1, S=1, float32
  input, nanargmin picks the same level across paths.
- device dispatch: 'cpu' / 'auto' resolution, unavailable cuda / mps
  raise clean RuntimeError, cpu and auto produce identical results.
- perform_gsn integration: basic call, rank-deficient regression,
  torch vs numpy end-to-end equivalence, opt['device'] threading,
  determinism across repeated calls, uneven-trials path intact.

42 tests, all passing.
Mirror of the Python-side svd -> eigh change. The input is symmetric
(we symmetrize on line 26), so eig is the right tool: for symmetric M
with eigendecomposition V*D*V', the SVD-based form (M + V*|D|*V')/2
simplifies to V*max(D,0)*V' — exactly what we now do here directly.
~1.4-1.5x cheaper per call than svd on symmetric input.

The eig fallback that the old code used when svd failed to converge
is also dropped — eig on a (symmetrized) matrix is numerically
robust enough that the svd-failure escape path is no longer needed.

Equivalence tests Python<->MATLAB remain 10/10 with both languages on
the eig path; floating-point reordering between LAPACK paths kept the
diffs at machine precision.
Sweeps nunits and times perform_gsn / performgsn with K repeats per
cell. Auto-detects available backends (python-numpy, python-torch-cpu,
python-torch-cuda, python-torch-mps, matlab) and renders a 3-panel
figure: absolute runtime, power-law extrapolation to N=1e6 (fit on
N > 1000 only — small N is overhead-dominated), and relative speedup
vs python-main-reference. Outputs gitignored.
cluster/: rsync code, SLURM array job (one H100 per nunits), Python
driver times perform_gsn across python-numpy / python-torch-cpu /
python-torch-cuda. Shards merge to one JSON consumable by
tests/test_speedup_magnitude.py.

tests/test_gsn_gpu_edge_cases.py: 19 skip-able CUDA/MPS tests for the
gsn.batched_nll torch path — gpu↔cpu NLL parity, per-slot NaN under
cholesky_ex, dtype handling, consecutive-call independence, end-to-end
opt['device'] threading.
cluster/ is user/machine-specific (SLURM, conda paths, lab storage
locations); gitignored and untracked here.

tests/test_gsn_gpu_edge_cases.py: replace absolute-only tolerances
with atol + rtol*|ref| so f32 NLLs at large N stay within precision;
add per-test argmin-agreement check — the invariant downstream code
actually depends on.
tests/test_gsn_gpu_edge_cases.py: replace absolute-only tolerances
with atol + rtol*|ref| so f32 NLLs at large N stay within precision;
add per-test argmin-agreement check — the invariant downstream code
actually depends on.
New gsn/fast_perform_gsn.py runs the full GSN pipeline — noise+data
covariance, held-out shrinkage selection, biconvex loop, ncsnr — on a
single backend (numpy or torch on CPU/CUDA/MPS) without round-tripping
through host memory between stages.

What changed vs. the previous calc_shrunken_covariance + rsa_noise_ceiling
flow:
  * Einsum for the 3D pooled noise covariance, in place of the per-condition
    np.cov loop. On torch this collapses ncond_train kernel launches into one.
  * Biconvex iteration stays on device. construct_nearest_psd_covariance
    previously used numpy.linalg even when torch was available, so cSb/cNb
    round-tripped host<->device every iteration; the new _nearest_psd is
    written against the active backend.
  * End-to-end on one backend. Data moves to device once at entry; we only
    materialize numpy at the very end when building the results dict.

perform_gsn becomes a thin defaults+dispatcher (no more mode/ncsims/wantfig
setup for the rsa_noise_ceiling indirection). Uneven trials still fall
through to rsa_noise_ceiling.

batched_nll._torch_dtype_for now accepts both numpy and torch dtypes so
fast_perform_gsn can hand a device tensor straight to _torch_batched
without a numpy round trip.

Local cpu wall-clock (mac): 2-3x faster than the previous fast path
across N=200/500/1000. MATLAB equivalence still 10/10; all unit tests
(103) pass.
Log-log axes, dashed extrapolation, time-reference lines, per-backend
power-law fits in the legend. Display names: numpy + scipy.linalg loop,
torch CPU (batched), torch CUDA (batched), gsn.perform_gsn (reference).
The (S, N, N) shrunken-cov tensor scales as S * N^2 * dtype-bytes.
With S=51, float64, N=20000 that's 160 GB — well past H100's 80 GB.
_pick_chunk_size picks the largest shrinkage-level chunk that fits in
~70% of free device memory (queried via torch.cuda.mem_get_info when
available, with safe fallbacks otherwise). The inner loop processes
chunks sequentially, deleting intermediate tensors before the next
chunk's allocations.

Verified bit-identical to single-pass when chunk_size=S (the common
case for N <= 3000); 103/103 unit tests still pass.
Previous 70% headroom triggered chunking at N=10000 f64 even though
single-pass would have fit. 95% leaves single-pass behavior intact
through ~N=8000 (any dtype) and ~N=10000 f64; chunking activates only
where it's truly required (N >= ~12000 f64 or ~N=15000 f32 on H100).
Motivated by running GSN at large nunits where the torch path OOMed
because the result dict carried all four (N, N) cov matrices through
host memory while biconvex was still holding device tensors, and
biconvex itself held more intermediates than necessary.

opt['returns'] selector: callers pick which of cN / cS / cNb / cSb
they actually need. Default still emits all four for backwards
compat. PSN-only workloads (cSb + cNb consumers) can drop cN / cS
and save 2 * N^2 * dtype-bytes of host memory.

Memory cleanup in _run_torch:
- cS is no longer materialized unless 'cS' in returns. ncsnr needs
  only its diagonal, computed directly from diag(cD) - diag(cN)/ntrial.
- _biconvex_torch no longer takes cS; derives the iter-0 anchor
  inline as cD - cN/ntrial.
- Intermediate data tensors freed as soon as their cov is built;
  torch.cuda.empty_cache() between the big allocations.

_flat_pearson replaces torch.corrcoef for the biconvex convergence
check. cuBLAS dot caps its input length at int32 (~2.1e9 elements);
at large N the corrcoef path internally builds an intermediate
(2, N^2) stack that trips the cap. _flat_pearson uses element-wise
mul + reduce-sum with f64 accumulators (int64 strides, no cap).

_nearest_psd_torch:
- f32 -> f64 upcast for eigh. cuSOLVER syevd is unreliable on
  near-singular f32; upcasting fixes spurious negative eigenvalues
  at large N.
- scipy.linalg.eigh CPU fallback if the device eigh raises. cuSOLVER
  syevd hits its workspace limit at very large N — the fallback
  keeps us going at the cost of one host round-trip.

batched_nll: build the (S, N, N) shrunk-cov stack in place by scaling
by alpha and then restoring the diagonal. Equivalent to the previous
alpha*c + (1-alpha)*diag(c) but peak transient drops from
~3*chunk*N^2 to ~1*chunk*N^2.

Test benchmark pins opt['returns'] to ['cSb', 'cNb'] so the cross-
backend wall-clock matches the legacy main reference (which doesn't
do the three extra eighs that the default 'returns' triggers).
replace the cpu-only _delegate_uneven fallback so uneven (nan-padded)
trial counts get the same torch/gpu speedups as even data.

- noise cov: per-condition pooled covariance over valid trials as one
  masked weighted gemm, with cv shrinkage selection.
- data cov: deterministic min-trial truncation, then the existing 2d cov.
- biconvex: add ntrialbc param (division uses ntrial=min(validcnt),
  coefficients use the average); default leaves the even path unchanged.
- routing honors opt['device']; opt['uneven']='reference' keeps the old
  rsa delegation as a parity oracle.

matches the reference rsa path to ~1e-14 and matches matlab on the
uneven equivalence fixtures; cuda-validated; even-path tests unchanged.
adds test_fast_uneven_matches_reference.
new opt['uneven']='missing' path for per-electrode/per-unit artifact
rejection, where a trial may have some units present and others missing
(standard gsn requires whole-trial validity and would discard good data).

- cN: average over conditions (>=2 shared-clean trials) of the UNBIASED
  pairwise covariance of each unit pair over their shared-clean trials
  (each pair centered on its shared trials; closed form via 3 masked gemms
  S2=Xm Xm^T, Si=Xm Mb^T, K=Mb Mb^T, cov=(S2-Si.Si^T/K)/(K-1)).
- cD: unbiased pairwise covariance across conditions of per-unit
  condition-means, over conditions where both units are defined.
- bias: cS = cD - cN (.) alpha, alpha[i,j] = avg_c n_ij,c/(n_i,c n_j,c);
  generalizes the scalar 1/ntrial and reduces to it when complete.
- biconvex: exact per-entry alpha in the cSb step, effective scalar ntrial
  in the regularizer coefficients.
- shrinkage levels selected on the complete-data subset, applied to the
  per-entry covs.

pairwise (not available-means) centering avoids the partial-overlap bias
k/(k-1)*(1-1/n_i-1/n_j+k/(n_i n_j)). pinned by: exact reduction to the even
path on complete data, brute-force pairwise references for cN/cD/alpha, and
a monte-carlo unbiasedness test (catches the available-means bias).
tests/test_missing_units.py (16 tests). numpy only; torch path to follow.
device-native version of the missing-units estimator. the per-condition
pairwise-covariance loop and biconvex run on the active device; shrinkage
levels are picked on the host (cheap scalars, reusing the numpy selectors)
and applied on device. routing: opt['uneven']='missing' uses torch when
available (honoring opt['device']), else numpy.

pairwise cov per block via the same closed form (S2, Si=Xm Mb^T, K=Mb Mb^T).
new test_torch_matches_numpy asserts torch(cpu) == numpy on cN/cS/cNb/cSb,
ncsnr, means, eigenvalues, shrink levels, and the signal subspace.
- batched_nll: clone the expanded diagonal RHS so the N==1 shrinkage NLL
  no longer hits a torch memory-alias error (also unbroke
  calc_shrunken_covariance and fast_perform_gsn on single-variable data).
- honor opt['shrinklevels'] everywhere via a centralized _get_shrinklevels
  (the fast/uneven/missing paths previously always used the default grid).
- uneven noise-cov: assert when the held-out validation split has no
  condition with >=2 valid trials, instead of silently returning shrink
  level 0 on an all-NaN nll (matches rsa_noise_ceiling).
- fast_perform_gsn / run_missing_units_numpy: guard data is 3D with >=2
  trials (matches the reference 'Number of trials must be at least 2').

tests/test_bugfixes.py covers all four.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant