Skip to content

Commit fe0eed2

Browse files
committed
replace numpy array with memmap file to reduce memory usage
1 parent 972d5d3 commit fe0eed2

1 file changed

Lines changed: 117 additions & 40 deletions

File tree

tpcav/cavs.py

Lines changed: 117 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
from copy import deepcopy
2121
from scipy import stats
22+
import uuid
2223
from sklearn.linear_model import SGDClassifier
2324
from sklearn.metrics import precision_recall_fscore_support
2425
from sklearn.metrics.pairwise import cosine_similarity
@@ -202,9 +203,12 @@ def predict(self, x: np.ndarray) -> np.ndarray:
202203
return self.lm.predict(x)
203204

204205
def prepare_xy(concept_embeddings, control_embeddings, seed=42):
206+
def _to_numpy(emb):
207+
return np.load(str(emb), mmap_mode="r")
208+
205209
# move to CPU + numpy, just double confirm
206-
concept = concept_embeddings.detach().cpu().numpy()
207-
control = control_embeddings.detach().cpu().numpy()
210+
concept = _to_numpy(concept_embeddings)
211+
control = _to_numpy(control_embeddings)
208212

209213
# labels
210214
y_concept = np.zeros(len(concept), dtype=np.int64)
@@ -228,8 +232,8 @@ def prepare_xy(concept_embeddings, control_embeddings, seed=42):
228232
return X_train, y_train, X_test, y_test
229233

230234
def _train(
231-
concept_embeddings: torch.Tensor,
232-
control_embeddings: torch.Tensor,
235+
concept_embeddings: str,
236+
control_embeddings: str,
233237
output_dir: str,
234238
penalty: str = "l2",
235239
backend: str = "sklearn",
@@ -338,6 +342,44 @@ def set_control(self, control_concept, num_samples: int) -> torch.Tensor:
338342
)
339343
return self.control_embeddings
340344

345+
@staticmethod
346+
def _save_tensor_npy(path: Path, tensor: torch.Tensor) -> str:
347+
path.parent.mkdir(parents=True, exist_ok=True)
348+
np.save(path, tensor.detach().cpu().numpy())
349+
return str(path)
350+
351+
@staticmethod
352+
def _cleanup_paths(paths: list[str]) -> None:
353+
for p in paths:
354+
try:
355+
Path(p).unlink(missing_ok=True)
356+
except Exception:
357+
pass
358+
359+
@classmethod
360+
def _reap_done_futures(cls, futures: list):
361+
pending = []
362+
for name, fut, paths in futures:
363+
if fut.done():
364+
fut.result() # raises if worker failed
365+
cls._cleanup_paths(paths)
366+
else:
367+
pending.append((name, fut, paths))
368+
return pending
369+
370+
@classmethod
371+
def _wait_for_capacity(
372+
cls,
373+
futures: list,
374+
capacity: int,
375+
sleep_s: int = 5,
376+
):
377+
while True:
378+
futures = cls._reap_done_futures(futures)
379+
if len(futures) < capacity:
380+
return futures
381+
time.sleep(sleep_s)
382+
341383
def train_concepts(
342384
self,
343385
concept_list,
@@ -356,22 +398,33 @@ def train_concepts(
356398
else:
357399
self.control_embeddings = self.control_embeddings.cpu()
358400

401+
output_dir_path = Path(output_dir)
402+
output_dir_path.mkdir(parents=True, exist_ok=True)
403+
control_memmap_path = output_dir_path / f"_control_embeddings_{uuid.uuid4().hex}.npy"
404+
self._save_tensor_npy(control_memmap_path, self.control_embeddings)
405+
359406
if num_processes == 1:
360407
for c in concept_list:
361408
concept_embeddings = self.tpcav.concept_embeddings(
362409
c, num_samples=num_samples
363410
)
411+
concept_dir = output_dir_path / c.name
412+
concept_dir.mkdir(parents=True, exist_ok=True)
413+
concept_memmap_path = concept_dir / "concept_embeddings.npy"
414+
self._save_tensor_npy(concept_memmap_path, concept_embeddings)
364415
fscore, weight = _train(
365-
concept_embeddings.cpu(),
366-
self.control_embeddings.cpu(),
367-
Path(output_dir) / c.name,
416+
str(concept_memmap_path),
417+
str(control_memmap_path),
418+
concept_dir,
368419
self.penalty,
369420
backend=backend,
370421
device=device
371422
)
372423
self.cav_fscores[c.name] = fscore
373424
self.cav_weights[c.name] = weight
374425
self.cavs_list.append(weight)
426+
427+
self._cleanup_paths([str(concept_memmap_path)])
375428
else:
376429
futures = []
377430
ctx = mp.get_context("spawn")
@@ -381,36 +434,39 @@ def train_concepts(
381434
c, num_samples=num_samples
382435
)
383436

384-
# block the process to avoid too long queue
385-
while True:
386-
done = [f for (_, f) in futures if f.done()]
387-
for f in done:
388-
f.result() # raises if worker failed
389-
390-
pending = [f for (_, f) in futures if not f.done()]
391-
if len(pending) < (max_pending + num_processes):
392-
break
437+
concept_dir = output_dir_path / c.name
438+
concept_dir.mkdir(parents=True, exist_ok=True)
439+
concept_memmap_path = concept_dir / "concept_embeddings.npy"
440+
self._save_tensor_npy(concept_memmap_path, concept_embeddings)
393441

394-
time.sleep(5)
442+
# block the process to avoid too long queue
443+
futures = self._wait_for_capacity(
444+
futures, capacity=(max_pending + num_processes), sleep_s=5
445+
)
395446

396447
future = executor.submit(
397448
_train,
398-
concept_embeddings.cpu(),
399-
self.control_embeddings,
400-
Path(output_dir) / c.name,
449+
str(concept_memmap_path),
450+
str(control_memmap_path),
451+
concept_dir,
401452
self.penalty,
402453
backend=backend,
403454
device=device
404455
)
405456
logger.info("Submitted CAV training for concept %s", c.name)
406-
futures.append((c.name, future))
457+
futures.append((c.name, future, [str(concept_memmap_path)]))
407458

408-
results = [(name, f.result()) for name, f in futures]
459+
results = []
460+
for name, fut, paths in futures:
461+
results.append((name, fut.result()))
462+
self._cleanup_paths(paths)
409463
for name, (fscore, weight) in results:
410464
self.cav_fscores[name] = fscore
411465
self.cav_weights[name] = weight
412466
self.cavs_list.append(weight)
413467

468+
self._cleanup_paths([str(control_memmap_path)])
469+
414470
def train_concepts_pairs(self,
415471
concept_pair_list,
416472
num_samples: int,
@@ -423,6 +479,9 @@ def train_concepts_pairs(self,
423479
424480
Note: It would compute embeddings on every control concept, use self.train_concepts if control concept is fixed
425481
"""
482+
output_dir_path = Path(output_dir)
483+
output_dir_path.mkdir(parents=True, exist_ok=True)
484+
426485
if num_processes == 1:
427486
for c_test, c_control in concept_pair_list:
428487
concept_embeddings = self.tpcav.concept_embeddings(
@@ -432,17 +491,26 @@ def train_concepts_pairs(self,
432491
c_control, num_samples=num_samples
433492
)
434493

494+
concept_dir = output_dir_path / c_test.name
495+
concept_dir.mkdir(parents=True, exist_ok=True)
496+
concept_memmap_path = concept_dir / "concept_embeddings.npy"
497+
control_memmap_path = concept_dir / "control_embeddings.npy"
498+
self._save_tensor_npy(concept_memmap_path, concept_embeddings)
499+
self._save_tensor_npy(control_memmap_path, control_embeddings)
500+
435501
fscore, weight = _train(
436-
concept_embeddings.cpu(),
437-
control_embeddings.cpu(),
438-
Path(output_dir) / c_test.name,
502+
str(concept_memmap_path),
503+
str(control_memmap_path),
504+
concept_dir,
439505
self.penalty,
440506
backend=backend,
441507
device=device
442508
)
443509
self.cav_fscores[c_test.name] = fscore
444510
self.cav_weights[c_test.name] = weight
445511
self.cavs_list.append(weight)
512+
513+
self._cleanup_paths([str(concept_memmap_path), str(control_memmap_path)])
446514
else:
447515
futures = []
448516
with ProcessPoolExecutor(max_workers=num_processes) as executor:
@@ -454,31 +522,40 @@ def train_concepts_pairs(self,
454522
c_control, num_samples=num_samples
455523
)
456524

457-
# block the process to avoid too long queue
458-
while True:
459-
done = [f for (_, f) in futures if f.done()]
460-
for f in done:
461-
f.result() # raises if worker failed
462-
463-
pending = [f for (_, f) in futures if not f.done()]
464-
if len(pending) < (max_pending + num_processes):
465-
break
525+
concept_dir = output_dir_path / c_test.name
526+
concept_dir.mkdir(parents=True, exist_ok=True)
527+
concept_memmap_path = concept_dir / "concept_embeddings.npy"
528+
control_memmap_path = concept_dir / "control_embeddings.npy"
529+
self._save_tensor_npy(concept_memmap_path, concept_embeddings)
530+
self._save_tensor_npy(control_memmap_path, control_embeddings)
466531

467-
time.sleep(5)
532+
# block the process to avoid too long queue
533+
futures = self._wait_for_capacity(
534+
futures, capacity=(max_pending + num_processes), sleep_s=5
535+
)
468536

469537
future = executor.submit(
470538
_train,
471-
concept_embeddings.cpu(),
472-
control_embeddings.cpu(),
473-
Path(output_dir) / c_test.name,
539+
str(concept_memmap_path),
540+
str(control_memmap_path),
541+
concept_dir,
474542
self.penalty,
475543
backend=backend,
476544
device=device
477545
)
478546
logger.info("Submitted CAV training for concept %s", c_test.name)
479-
futures.append((c_test.name, future))
547+
futures.append(
548+
(
549+
c_test.name,
550+
future,
551+
[str(concept_memmap_path), str(control_memmap_path)],
552+
)
553+
)
480554

481-
results = [(name, f.result()) for name, f in futures]
555+
results = []
556+
for name, fut, paths in futures:
557+
results.append((name, fut.result()))
558+
self._cleanup_paths(paths)
482559
for name, (fscore, weight) in results:
483560
self.cav_fscores[name] = fscore
484561
self.cav_weights[name] = weight

0 commit comments

Comments
 (0)