Skip to content

Commit 8ec6468

Browse files
committed
improve naming
1 parent 2c7824a commit 8ec6468

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

tpcav/cavs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ class CavTrainer:
154154
def __init__(self, tpcav: TPCAV, penalty: str = "l2") -> None:
155155
self.tpcav = tpcav
156156
self.penalty = penalty
157-
self.cavs_fscores = {}
157+
self.cav_fscores = {}
158158
self.cav_weights = {}
159159
self.control_embeddings: Optional[torch.Tensor] = None
160160
self.cavs_list: List[torch.Tensor] = []
@@ -165,7 +165,7 @@ def save_state(self, output_path: str = "cav_trainer_state.pt"):
165165
"""
166166
state = {
167167
"penalty": self.penalty,
168-
"cavs_fscores": self.cavs_fscores,
168+
"cav_fscores": self.cav_fscores,
169169
"cav_weights": self.cav_weights,
170170
"control_embeddings": self.control_embeddings,
171171
"cavs_list": self.cavs_list,
@@ -180,7 +180,7 @@ def load_state(tpcav_model: TPCAV, state_path: str = "cav_trainer_state.pt"):
180180
state = torch.load(state_path, map_location="cpu")
181181
cav_trainer = CavTrainer(tpcav_model, penalty=state["penalty"])
182182

183-
cav_trainer.cavs_fscores = state["cavs_fscores"]
183+
cav_trainer.cav_fscores = state["cav_fscores"]
184184
cav_trainer.cav_weights = state["cav_weights"]
185185
cav_trainer.control_embeddings = state["control_embeddings"]
186186
cav_trainer.cavs_list = state["cavs_list"]
@@ -224,7 +224,7 @@ def train_concepts(
224224
Path(output_dir) / c.name,
225225
self.penalty,
226226
)
227-
self.cavs_fscores[c.name] = fscore
227+
self.cav_fscores[c.name] = fscore
228228
self.cav_weights[c.name] = weight
229229
self.cavs_list.append(weight)
230230
else:
@@ -259,7 +259,7 @@ def train_concepts(
259259

260260
results = [(name, f.result()) for name, f in futures]
261261
for name, (fscore, weight) in results:
262-
self.cavs_fscores[name] = fscore
262+
self.cav_fscores[name] = fscore
263263
self.cav_weights[name] = weight
264264
self.cavs_list.append(weight)
265265

@@ -288,7 +288,7 @@ def train_concepts_pairs(self,
288288
Path(output_dir) / c_test.name,
289289
self.penalty,
290290
)
291-
self.cavs_fscores[c_test.name] = fscore
291+
self.cav_fscores[c_test.name] = fscore
292292
self.cav_weights[c_test.name] = weight
293293
self.cavs_list.append(weight)
294294
else:
@@ -326,7 +326,7 @@ def train_concepts_pairs(self,
326326

327327
results = [(name, f.result()) for name, f in futures]
328328
for name, (fscore, weight) in results:
329-
self.cavs_fscores[name] = fscore
329+
self.cav_fscores[name] = fscore
330330
self.cav_weights[name] = weight
331331
self.cavs_list.append(weight)
332332

@@ -397,14 +397,14 @@ def plot_cavs_similaritiy_heatmap(
397397
cavs_pass = []
398398
cavs_names_pass = []
399399
for cname in cavs_names:
400-
if self.cavs_fscores[cname] >= fscore_thresh:
400+
if self.cav_fscores[cname] >= fscore_thresh:
401401
cavs_pass.append(self.cav_weights[cname].cpu().numpy())
402402
cavs_names_pass.append(cname)
403403
else:
404404
logger.info(
405405
"Skipping CAV %s with F-score %.3f below threshold %.3f",
406406
cname,
407-
self.cavs_fscores[cname],
407+
self.cav_fscores[cname],
408408
fscore_thresh,
409409
)
410410
if len(cavs_pass) == 0:
@@ -533,8 +533,8 @@ def compute_motif_auc_fscore(num_motif_insertions: List[int], cav_trainers: List
533533

534534
assert motif_file_fmt in ['meme', 'consensus']
535535

536-
cavs_fscores_df = pd.DataFrame({nm: cav_trainer.cavs_fscores for nm, cav_trainer in zip(num_motif_insertions, cav_trainers)})
537-
cavs_fscores_df['concept'] = list(cav_trainers[0].cavs_fscores.keys())
536+
cavs_fscores_df = pd.DataFrame({nm: cav_trainer.cav_fscores for nm, cav_trainer in zip(num_motif_insertions, cav_trainers)})
537+
cavs_fscores_df['concept'] = list(cav_trainers[0].cav_fscores.keys())
538538

539539
def compute_auc_fscore(row):
540540
y = [row[nm] for nm in num_motif_insertions]

0 commit comments

Comments
 (0)