Skip to content

Commit 0d1e592

Browse files
committed
add help functions
1 parent 0df3f9c commit 0d1e592

3 files changed

Lines changed: 35 additions & 13 deletions

File tree

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
# TPCAV (Testing with PCA projected Concept Activation Vectors)
22

3-
Analysis pipeline for TPCAV
3+
This repository contains code to compute TPCAV (Testing with PCA projected Concept Activation Vectors) on deep learning models. TPCAV is an extension of the original TCAV method, which uses PCA to reduce the dimensionality of the activations at a selected intermediate layer before computing Concept Activation Vectors (CAVs)
44

5-
## Dependencies
5+
## Installation
66

7-
You can use your own environment for the model, in addition, you need to install the following packages:
87

9-
- captum 0.7
10-
- seqchromloader 0.8.5
11-
- scikit-learn 1.5.2
128

139
## Workflow
1410

tpcav/cavs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,16 @@ def tpcav_score(
246246

247247
return scores
248248

249+
def tpcav_score_all_concepts(self, attributions: torch.Tensor) -> dict:
250+
"""
251+
Compute TCAV scores for all trained concepts.
252+
"""
253+
scores_dict = {}
254+
for concept_name in self.cav_weights.keys():
255+
scores = self.tpcav_score(concept_name, attributions)
256+
scores_dict[concept_name] = scores
257+
return scores_dict
258+
249259
def tpcav_score_binary_log_ratio(
250260
self, concept_name: str, attributions: torch.Tensor, pseudocount: float = 1.0
251261
) -> float:
@@ -259,6 +269,20 @@ def tpcav_score_binary_log_ratio(
259269

260270
return np.log((pos_count + pseudocount) / (neg_count + pseudocount))
261271

272+
def tpcav_score_all_concepts_log_ratio(
273+
self, attributions: torch.Tensor, pseudocount: float = 1.0
274+
) -> dict:
275+
"""
276+
Compute TCAV log ratio scores for all trained concepts.
277+
"""
278+
log_ratio_dict = {}
279+
for concept_name in self.cav_weights.keys():
280+
log_ratio = self.tpcav_score_binary_log_ratio(
281+
concept_name, attributions, pseudocount
282+
)
283+
log_ratio_dict[concept_name] = log_ratio
284+
return log_ratio_dict
285+
262286
def plot_cavs_similaritiy_heatmap(
263287
self,
264288
attributions: torch.Tensor,

tpcav/tpcav_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111

1212
def _abs_attribution_func(multipliers, inputs, baselines):
1313
"Multiplier x abs(inputs - baselines) to avoid double-sign effects."
14-
# print(f"inputs: {inputs[1][:5]}")
15-
# print(f"baselines: {baselines[1][:5]}")
16-
# print(f"multipliers: {multipliers[0][:5]}")
17-
# print(f"multipliers: {multipliers[1][:5]}")
1814
return tuple(
1915
(input_ - baseline).abs() * multiplier
2016
for input_, baseline, multiplier in zip(inputs, baselines, multipliers)
@@ -174,8 +170,12 @@ def layer_attributions(
174170
target_batches: Iterable,
175171
baseline_batches: Iterable,
176172
multiply_by_inputs: bool = True,
173+
abs_inputs_diff: bool = True,
177174
) -> Dict[str, torch.Tensor]:
178-
"""Compute DeepLift attributions on PCA embedding space.
175+
"""
176+
Compute DeepLift attributions on PCA embedding space.
177+
178+
By default, it computes (input - baseline).abs() * multiplier to avoid double-sign effects (abs_inputs_diff=True).
179179
180180
target_batches and baseline_batches should yield (seq, chrom) pairs of matching length.
181181
"""
@@ -184,6 +184,8 @@ def layer_attributions(
184184
self.forward = self.forward_from_embeddings_at_layer
185185
deeplift = DeepLift(self, multiply_by_inputs=multiply_by_inputs)
186186

187+
custom_attr_func = _abs_attribution_func if abs_inputs_diff else None
188+
187189
attributions = []
188190
for inputs, binputs in zip(target_batches, baseline_batches):
189191
avs = self._layer_output(*[i.to(self.device) for i in inputs])
@@ -205,7 +207,7 @@ def layer_attributions(
205207
),
206208
additional_forward_args=(inputs,),
207209
custom_attribution_func=(
208-
None if not multiply_by_inputs else _abs_attribution_func
210+
None if not multiply_by_inputs else custom_attr_func
209211
),
210212
)
211213
attr_residual, attr_projected = attribution
@@ -219,7 +221,7 @@ def layer_attributions(
219221
inputs,
220222
),
221223
custom_attribution_func=(
222-
None if not multiply_by_inputs else _abs_attribution_func
224+
None if not multiply_by_inputs else custom_attr_func
223225
),
224226
)[0]
225227

0 commit comments

Comments
 (0)