11import unittest
2+ from functools import partial
23from pathlib import Path
34
45import torch
56from Bio import motifs as Bio_motifs
7+ from captum .attr import DeepLift
68
79from tpcav import helper
810from tpcav .cavs import CavTrainer
911from tpcav .concepts import ConceptBuilder
10- from tpcav .tpcav_model import TPCAV
12+ from tpcav .tpcav_model import TPCAV , _abs_attribution_func
1113
1214
1315class DummyModelSeq (torch .nn .Module ):
@@ -22,6 +24,11 @@ def forward(self, seq):
2224 y_hat = self .layer2 (y_hat )
2325 return y_hat
2426
27+ def foward_from_layer1 (self , y_hat ):
28+ y_hat = y_hat .squeeze (- 1 )
29+ y_hat = self .layer2 (y_hat )
30+ return y_hat
31+
2532
2633class DummyModelSeqChrom (torch .nn .Module ):
2734 def __init__ (self ):
@@ -174,6 +181,49 @@ def pack_data_iters(df):
174181 ],
175182 )
176183
184+ # compute layer attributions using the old way
185+ random1_avs = []
186+ random2_avs = []
187+ for inputs in pack_data_iters (random_regions_1 ):
188+ av = tpcav_model ._layer_output (* [i .to (tpcav_model .device ) for i in inputs ])
189+ random1_avs .append (av .detach ().cpu ())
190+ for inputs in pack_data_iters (random_regions_2 ):
191+ av = tpcav_model ._layer_output (* [i .to (tpcav_model .device ) for i in inputs ])
192+ random2_avs .append (av .detach ().cpu ())
193+ random1_avs = torch .cat (random1_avs , dim = 0 )
194+ random2_avs = torch .cat (random2_avs , dim = 0 )
195+
196+ random1_avs_residual , random1_avs_projected = tpcav_model .project_activations (
197+ random1_avs
198+ )
199+ random2_avs_residual , random2_avs_projected = tpcav_model .project_activations (
200+ random2_avs
201+ )
202+
203+ def forward_from_layer_1_embeddings (tm , avs_residual , avs_projected ):
204+ y_hat = tm .embedding_to_layer_activation (avs_residual , avs_projected )
205+ y_hat = tm .model .foward_from_layer1 (y_hat )
206+ return y_hat
207+
208+ tpcav_model .forward = partial (forward_from_layer_1_embeddings , tpcav_model )
209+
210+ dl = DeepLift (tpcav_model )
211+ attributions_old = dl .attribute (
212+ (
213+ random1_avs_residual .to (tpcav_model .device ),
214+ random1_avs_projected .to (tpcav_model .device ),
215+ ),
216+ baselines = (
217+ random2_avs_residual .to (tpcav_model .device ),
218+ random2_avs_projected .to (tpcav_model .device ),
219+ ),
220+ custom_attribution_func = _abs_attribution_func ,
221+ )
222+ attr_residual , attr_projected = attributions_old
223+ attributions_old = torch .cat ((attr_projected , attr_residual ), dim = 1 )
224+
225+ self .assertTrue (torch .allclose (attributions .cpu (), attributions_old .cpu ()))
226+
177227
178228if __name__ == "__main__" :
179229 unittest .main ()
0 commit comments