Skip to content

Commit 364578e

Browse files
committed
minor bug fixes & imporve log message
1 parent 5244ec3 commit 364578e

3 files changed

Lines changed: 25 additions & 20 deletions

File tree

test/test_cav_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def test_all(self):
279279
cav_trainer.set_control(builder.control_concepts[0], num_samples=100)
280280

281281
cav_trainer.train_concepts(
282-
builder.concepts, 100, output_dir="data/cavs/", num_processes=2, backend='torch', device='cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0',
282+
builder.concepts, 100, output_dir="data/cavs/", num_processes=1, backend='torch', device='cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0',
283283
)
284284
cav_trainer.train_concepts(
285285
builder.concepts, 100, output_dir="data/cavs/", num_processes=2

tpcav/cavs.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _train(
238238
penalty: str = "l2",
239239
backend: str = "sklearn",
240240
device=None,
241+
name=None,
241242
) -> Tuple[float, torch.Tensor]:
242243
"""
243244
Train a binary CAV classifier for a concept vs cached control embeddings.
@@ -270,16 +271,18 @@ def _eval(avs, l, name: str):
270271
precision, recall, fscore, support = precision_recall_fscore_support(
271272
l, y_preds, average="binary", pos_label=1
272273
)
273-
logger.info("[%s] Accuracy: %.4f", name, acc)
274+
#logger.info("[%s] Accuracy: %.4f", name, acc)
274275
(output_dir / f"classifier_perform_on_{name}.txt").write_text(
275276
f"Accuracy: {acc}\n"
276277
)
277278
return fscore
278279

279280
output_dir.mkdir(parents=True, exist_ok=True)
280-
_eval(train_avs, train_l, "train")
281+
train_fscore = _eval(train_avs, train_l, "train")
281282
test_fscore = _eval(test_avs, test_l, "test")
282283

284+
logger.info("Concept %s: [train] F-score: %.4f, [test] F-score: %.4f", name, train_fscore, test_fscore)
285+
283286
weights = clf.weights
284287
assert len(weights.shape) == 2 and weights.shape[0] == 2
285288
torch.save(weights, output_dir / "classifier_weights.pt")
@@ -357,11 +360,11 @@ def _cleanup_paths(paths: list[str]) -> None:
357360
pass
358361

359362
@classmethod
360-
def _reap_done_futures(cls, futures: list):
363+
def _reap_done_futures(cls, futures: list, results: list):
361364
pending = []
362365
for name, fut, paths in futures:
363366
if fut.done():
364-
fut.result() # raises if worker failed
367+
results.append((name, fut.result())) # raises if worker failed
365368
cls._cleanup_paths(paths)
366369
else:
367370
pending.append((name, fut, paths))
@@ -371,11 +374,12 @@ def _reap_done_futures(cls, futures: list):
371374
def _wait_for_capacity(
372375
cls,
373376
futures: list,
377+
results: list,
374378
capacity: int,
375379
sleep_s: int = 5,
376380
):
377381
while True:
378-
futures = cls._reap_done_futures(futures)
382+
futures = cls._reap_done_futures(futures, results)
379383
if len(futures) < capacity:
380384
return futures
381385
time.sleep(sleep_s)
@@ -418,15 +422,16 @@ def train_concepts(
418422
concept_dir,
419423
self.penalty,
420424
backend=backend,
421-
device=device
425+
device=device,
426+
name=c.name,
422427
)
423428
self.cav_fscores[c.name] = fscore
424429
self.cav_weights[c.name] = weight
425430
self.cavs_list.append(weight)
426431

427432
self._cleanup_paths([str(concept_memmap_path)])
428433
else:
429-
futures = []
434+
futures = []; results = []
430435
ctx = mp.get_context("spawn")
431436
with ProcessPoolExecutor(mp_context=ctx, max_workers=num_processes) as executor:
432437
for c in concept_list:
@@ -441,7 +446,7 @@ def train_concepts(
441446

442447
# block the process to avoid too long queue
443448
futures = self._wait_for_capacity(
444-
futures, capacity=(max_pending + num_processes), sleep_s=5
449+
futures, results, capacity=(max_pending + num_processes), sleep_s=5
445450
)
446451

447452
future = executor.submit(
@@ -451,12 +456,12 @@ def train_concepts(
451456
concept_dir,
452457
self.penalty,
453458
backend=backend,
454-
device=device
459+
device=device,
460+
name=c.name,
455461
)
456462
logger.info("Submitted CAV training for concept %s", c.name)
457463
futures.append((c.name, future, [str(concept_memmap_path)]))
458464

459-
results = []
460465
for name, fut, paths in futures:
461466
results.append((name, fut.result()))
462467
self._cleanup_paths(paths)
@@ -504,15 +509,16 @@ def train_concepts_pairs(self,
504509
concept_dir,
505510
self.penalty,
506511
backend=backend,
507-
device=device
512+
device=device,
513+
name=c_test.name,
508514
)
509515
self.cav_fscores[c_test.name] = fscore
510516
self.cav_weights[c_test.name] = weight
511517
self.cavs_list.append(weight)
512518

513519
self._cleanup_paths([str(concept_memmap_path), str(control_memmap_path)])
514520
else:
515-
futures = []
521+
futures = []; results = []
516522
with ProcessPoolExecutor(max_workers=num_processes) as executor:
517523
for c_test, c_control in concept_pair_list:
518524
concept_embeddings = self.tpcav.concept_embeddings(
@@ -531,7 +537,7 @@ def train_concepts_pairs(self,
531537

532538
# block the process to avoid too long queue
533539
futures = self._wait_for_capacity(
534-
futures, capacity=(max_pending + num_processes), sleep_s=5
540+
futures, results, capacity=(max_pending + num_processes), sleep_s=5
535541
)
536542

537543
future = executor.submit(
@@ -541,7 +547,8 @@ def train_concepts_pairs(self,
541547
concept_dir,
542548
self.penalty,
543549
backend=backend,
544-
device=device
550+
device=device,
551+
name=c_test.name,
545552
)
546553
logger.info("Submitted CAV training for concept %s", c_test.name)
547554
futures.append(
@@ -552,7 +559,6 @@ def train_concepts_pairs(self,
552559
)
553560
)
554561

555-
results = []
556562
for name, fut, paths in futures:
557563
results.append((name, fut.result()))
558564
self._cleanup_paths(paths)

tpcav/report.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def compute_ic(row: Any) -> float:
8383
buf = io.BytesIO()
8484
fig.savefig(buf, format="png", dpi=200, bbox_inches="tight", transparent=True)
8585
plt.close(fig)
86-
out[str(name)] = "data:image/png;base64," + base64.b64encode(
86+
out[_utils.clean_motif_name(str(name))] = "data:image/png;base64," + base64.b64encode(
8787
buf.getvalue()
8888
).decode("ascii")
8989

@@ -343,10 +343,9 @@ def _new_trainer() -> Any:
343343
# -----------------------------------------------------------------------------
344344
# 5) Build JS payload (used by Plotly)
345345
# -----------------------------------------------------------------------------
346-
motif_logo_concepts = selected_motif_concepts[:]
347346
motif_logo_dict = _maybe_build_motif_logo_data_uris(
348347
motif_file if motif_file_fmt == "meme" else None,
349-
motif_logo_concepts,
348+
selected_motif_concepts,
350349
)
351350
js_payload: dict[str, Any] = {
352351
"motif_file_fmt": motif_file_fmt,
@@ -435,7 +434,7 @@ def _to_list(x: Any) -> Any:
435434
if motif_auc_df is not None:
436435
# append motif logo column if exists
437436
if motif_file_fmt=='meme' and (len(motif_logo_dict)>0):
438-
motif_auc_df['motif_logo'] = motif_auc_df.apply(lambda x: "<img src=\"" + motif_logo_dict[x['concept']] + "\" width=\"100\">", axis=1)
437+
motif_auc_df['motif_logo'] = motif_auc_df.apply(lambda x: "<img src=\"" + motif_logo_dict.get(x['concept'], 'null') + "\" width=\"100\">", axis=1)
439438
motif_auc_table_html = _render_df_table(motif_auc_df, max_rows=5000)
440439

441440
if embed_images:

0 commit comments

Comments
 (0)