Skip to content

Commit 0df3f9c

Browse files
committed
fix the bug that caused the avs_projected gradient to be lost
1 parent f90b5c8 commit 0df3f9c

2 files changed

Lines changed: 59 additions & 4 deletions

File tree

test/test_cav_trainer.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import unittest
2+
from functools import partial
23
from pathlib import Path
34

45
import torch
56
from Bio import motifs as Bio_motifs
7+
from captum.attr import DeepLift
68

79
from tpcav import helper
810
from tpcav.cavs import CavTrainer
911
from tpcav.concepts import ConceptBuilder
10-
from tpcav.tpcav_model import TPCAV
12+
from tpcav.tpcav_model import TPCAV, _abs_attribution_func
1113

1214

1315
class DummyModelSeq(torch.nn.Module):
@@ -22,6 +24,11 @@ def forward(self, seq):
2224
y_hat = self.layer2(y_hat)
2325
return y_hat
2426

27+
def foward_from_layer1(self, y_hat):
28+
y_hat = y_hat.squeeze(-1)
29+
y_hat = self.layer2(y_hat)
30+
return y_hat
31+
2532

2633
class DummyModelSeqChrom(torch.nn.Module):
2734
def __init__(self):
@@ -174,6 +181,49 @@ def pack_data_iters(df):
174181
],
175182
)
176183

184+
# compute layer attributions using the old way
185+
random1_avs = []
186+
random2_avs = []
187+
for inputs in pack_data_iters(random_regions_1):
188+
av = tpcav_model._layer_output(*[i.to(tpcav_model.device) for i in inputs])
189+
random1_avs.append(av.detach().cpu())
190+
for inputs in pack_data_iters(random_regions_2):
191+
av = tpcav_model._layer_output(*[i.to(tpcav_model.device) for i in inputs])
192+
random2_avs.append(av.detach().cpu())
193+
random1_avs = torch.cat(random1_avs, dim=0)
194+
random2_avs = torch.cat(random2_avs, dim=0)
195+
196+
random1_avs_residual, random1_avs_projected = tpcav_model.project_activations(
197+
random1_avs
198+
)
199+
random2_avs_residual, random2_avs_projected = tpcav_model.project_activations(
200+
random2_avs
201+
)
202+
203+
def forward_from_layer_1_embeddings(tm, avs_residual, avs_projected):
204+
y_hat = tm.embedding_to_layer_activation(avs_residual, avs_projected)
205+
y_hat = tm.model.foward_from_layer1(y_hat)
206+
return y_hat
207+
208+
tpcav_model.forward = partial(forward_from_layer_1_embeddings, tpcav_model)
209+
210+
dl = DeepLift(tpcav_model)
211+
attributions_old = dl.attribute(
212+
(
213+
random1_avs_residual.to(tpcav_model.device),
214+
random1_avs_projected.to(tpcav_model.device),
215+
),
216+
baselines=(
217+
random2_avs_residual.to(tpcav_model.device),
218+
random2_avs_projected.to(tpcav_model.device),
219+
),
220+
custom_attribution_func=_abs_attribution_func,
221+
)
222+
attr_residual, attr_projected = attributions_old
223+
attributions_old = torch.cat((attr_projected, attr_residual), dim=1)
224+
225+
self.assertTrue(torch.allclose(attributions.cpu(), attributions_old.cpu()))
226+
177227

178228
if __name__ == "__main__":
179229
unittest.main()

tpcav/tpcav_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
import logging
32
from functools import partial
43
from typing import Dict, Iterable, List, Optional, Tuple
@@ -12,6 +11,10 @@
1211

1312
def _abs_attribution_func(multipliers, inputs, baselines):
1413
"Multiplier x abs(inputs - baselines) to avoid double-sign effects."
14+
# print(f"inputs: {inputs[1][:5]}")
15+
# print(f"baselines: {baselines[1][:5]}")
16+
# print(f"multipliers: {multipliers[0][:5]}")
17+
# print(f"multipliers: {multipliers[1][:5]}")
1518
return tuple(
1619
(input_ - baseline).abs() * multiplier
1720
for input_, baseline, multiplier in zip(inputs, baselines, multipliers)
@@ -61,7 +64,6 @@ def restore_tpcav_state(self, tpcav_state_dict: Dict) -> None:
6164
self._set_buffer("pca_inv", tpcav_state_dict["pca_inv"])
6265
self._set_buffer("orig_shape", tpcav_state_dict["orig_shape"])
6366
self.fitted = True
64-
print(inspect.currentframe().f_back.f_code.co_name)
6567
logger.warning(
6668
"Restored TPCAV state, please set model attribute!\n\n Example: self.model = Model_class()",
6769
)
@@ -190,7 +192,11 @@ def layer_attributions(
190192
bavs = self._layer_output(*[bi.to(self.device) for bi in binputs])
191193
bavs_residual, bavs_projected = self.project_activations(bavs)
192194

195+
# detach the projected tensor: it is still connected to the original input graph,
196+
# so detaching makes it a leaf tensor and gradients are attributed to it directly
193197
if avs_projected is not None:
198+
avs_projected = avs_projected.detach()
199+
bavs_projected = bavs_projected.detach()
194200
attribution = deeplift.attribute(
195201
(avs_residual.to(self.device), avs_projected.to(self.device)),
196202
baselines=(
@@ -258,7 +264,6 @@ def input_attributions(
258264

259265
attributions = []
260266
for inputs, binputs in zip(target_batches, baseline_batches):
261-
262267
attribution = deeplift.attribute(
263268
tuple([i.to(self.device) for i in inputs]),
264269
baselines=tuple([bi.to(self.device) for bi in binputs]),

0 commit comments

Comments
 (0)