@@ -65,10 +65,6 @@ def main():
6565 "output_prefix" ,
6666 help = "Prefix of output file name for attributions saved in npz format" ,
6767 )
68- parser .add_argument (
69- "output_key" ,
70- help = "Which output should be attributed to, available keys [Oct4_profile, Oct4_counts, Sox2_profile, Sox2_counts, Nanog_profile, Nanog_counts, Klf4_profile, Klf4_counts]" ,
71- )
7268 parser .add_argument (
7369 "--cavs-dir" ,
7470 type = str ,
@@ -168,9 +164,12 @@ def main():
168164
169165 # attribute each test sample
170166 # NOTE: attribution is a single tensor when chrom is None, otherwise a (seq, chrom) pair
171- attributions = []
172- attributions_x_avs = []
173- attributions_remainder = []
167+ attributions_seq = []
168+ attributions_chrom = []
169+ attributions_x_avs_seq = []
170+ attributions_remainder_seq = []
171+ attributions_x_avs_chrom = []
172+ attributions_remainder_chrom = []
174173 regions_save = []
175174 for (region , seq , chrom , _ , _ ), (bseq , bchrom , _ , _ ) in tqdm (
176175 zip (target_dl , baseline_dl )
@@ -181,20 +180,23 @@ def main():
181180
182181 # match repeated input shape
183182 seq = torch .repeat_interleave (seq , repeats = args .num_baselines_per_sample , dim = 0 )
183+ chrom = torch .repeat_interleave (chrom , repeats = args .num_baselines_per_sample , dim = 0 )
184184 bseq = bseq [: seq .shape [0 ]]
185- inputs = utils .seq_transform_fn (seq .to (device ))
186- binputs = utils .seq_transform_fn (bseq .to (device ))
185+ bchrom = bchrom [: chrom .shape [0 ]]
187186
188- neutral_biases = {k : v for k , v in inputs .items () if k != "seq" }
187+ seq = utils .seq_transform_fn (seq .to (device ))
188+ bseq = utils .seq_transform_fn (bseq .to (device ))
189189
190+ chrom = utils .chrom_transform_fn (chrom .to (device ))
191+ bchrom = utils .chrom_transform_fn (bchrom .to (device ))
192+
193+ inputs = seq if chrom is None else (seq , chrom )
194+ binputs = bseq if chrom is None else (bseq , bchrom )
190195 # attribution on full sequence
191196 attribution = deeplift .attribute (
192- inputs [ "seq" ] ,
193- baselines = binputs [ "seq" ] ,
197+ inputs ,
198+ baselines = binputs ,
194199 additional_forward_args = (
195- neutral_biases ,
196- args .output_key ,
197- True ,
198200 cavs_list ,
199201 False ,
200202 False ,
@@ -203,68 +205,46 @@ def main():
203205 # None if args.no_multiply_by_inputs else abs_attribution_func
204206 # ),
205207 ) # [# batch, dim_projected+dim_residual]
206- attributions .append (
207- attribution .reshape (
208- - 1 , args .num_baselines_per_sample , * attribution .shape [1 :]
209- )
210- .mean (axis = 1 )
211- .detach ()
212- .cpu ()
213- )
208+
209+ def reduce_attrs (attrs ):
210+ return attrs .reshape (- 1 , args .num_baselines_per_sample , * attrs .shape [1 :]).mean (axis = 1 ).detach ().cpu ()
211+
212+ if chrom is None :
213+ attributions_seq .append (reduce_attrs (attribution ))
214+ else :
215+ attributions_seq .append (reduce_attrs (attribution [0 ]))
216+ attributions_chrom .append (reduce_attrs (attribution [1 ]))
214217
215218 if cavs_list is not None :
216219 # attribution on x avs directions
217220 attribution_x_avs = deeplift .attribute (
218- inputs [ "seq" ] ,
219- baselines = binputs [ "seq" ] ,
221+ inputs ,
222+ baselines = binputs ,
220223 additional_forward_args = (
221- neutral_biases ,
222- args .output_key ,
223- True ,
224224 cavs_list ,
225225 False ,
226226 True ,
227227 ),
228- # custom_attribution_func=(
229- # None if args.no_multiply_by_inputs else abs_attribution_func
230- # ),
231228 ) # [# batch, dim_projected+dim_residual]
232229 # attribution on remainder
233230 attribution_remainder = deeplift .attribute (
234- inputs [ "seq" ] ,
235- baselines = binputs [ "seq" ] ,
231+ inputs ,
232+ baselines = binputs ,
236233 additional_forward_args = (
237- neutral_biases ,
238- args .output_key ,
239- True ,
240234 cavs_list ,
241235 True ,
242236 False ,
243237 ),
244- # custom_attribution_func=(
245- # None if args.no_multiply_by_inputs else abs_attribution_func
246- # ),
247238 ) # [# batch, dim_projected+dim_residual]
248- attributions_x_avs .append (
249- attribution_x_avs .reshape (
250- - 1 , args .num_baselines_per_sample , * attribution_x_avs .shape [1 :]
251- )
252- .mean (axis = 1 )
253- .detach ()
254- .cpu ()
255- )
256- attributions_remainder .append (
257- attribution_remainder .reshape (
258- - 1 , args .num_baselines_per_sample , * attribution_remainder .shape [1 :]
259- )
260- .mean (axis = 1 )
261- .detach ()
262- .cpu ()
263- )
264239
265- # make predictions
266- # target_preds[output_key].append(tpcav_model(inpt_projected.to(device), avs_residual.to(device), args.output_key).detach().cpu())
267- # baseline_preds[output_key].append(tpcav_model(bavs_projected.to(device), bavs_residual.to(device), args.output_key).detach().cpu())
240+ if chrom is None :
241+ attributions_x_avs_seq .append (reduce_attrs (attribution_x_avs ))
242+ attributions_remainder_seq .append (reduce_attrs (attribution_remainder ))
243+ else :
244+ attributions_x_avs_seq .append (reduce_attrs (attribution_x_avs [0 ]))
245+ attributions_x_avs_chrom .append (reduce_attrs (attribution_x_avs [1 ]))
246+ attributions_remainder_seq .append (reduce_attrs (attribution_remainder [0 ]))
247+ attributions_remainder_chrom .append (reduce_attrs (attribution_remainder [1 ]))
268248
269249 with torch .no_grad ():
270250 del (
@@ -274,6 +254,9 @@ def main():
274254 )
275255 torch .cuda .empty_cache ()
276256
257+ # save regions
258+ np .savetxt (f"{ args .output_prefix } .regions.txt" , regions_save , fmt = "%s" )
259+
277260 # save attributions
278261 def save_attrs (attrs , name ):
279262 attrs = torch .cat (attrs )
@@ -282,72 +265,50 @@ def save_attrs(attrs, name):
282265 return attrs
283266
284267 # sum over the last dimension to get per base pair attributions
285- attrs_all = save_attrs (attributions , "attributions " ).sum (dim = 2 )
286- # save regions
287- np . savetxt ( f" { args . output_prefix } .regions.txt" , regions_save , fmt = "%s" )
268+ attrs_all_seq = save_attrs (attributions_seq , "attributions_seq " ).sum (dim = 2 )
269+ if chrom is not None :
270+ attrs_all_chrom = save_attrs ( attributions_chrom , "attributions_chrom" ). sum ( dim = 2 )
288271
289272 if cavs_list is not None :
290- attrs_x_avs = save_attrs (attributions_x_avs , "attributions_x_avs " ).sum (dim = 2 )
291- attrs_remainder = save_attrs (
292- attributions_remainder , "attributions_remainder "
273+ attrs_x_avs_seq = save_attrs (attributions_x_avs_seq , "attributions_x_avs_seq " ).sum (dim = 2 )
274+ attrs_remainder_seq = save_attrs (
275+ attributions_remainder_seq , "attributions_remainder_seq "
293276 ).sum (dim = 2 )
277+ if chrom is not None :
278+ attrs_x_avs_chrom = save_attrs (attributions_x_avs_chrom , "attributions_x_avs_chrom" ).sum (dim = 2 )
279+ attrs_remainder_chrom = save_attrs (
280+ attributions_remainder_chrom , "attributions_remainder_chrom"
281+ ).sum (dim = 2 )
294282
295- # print summary statistics
296- def compute_attr_contrib (sign = "+" ):
297- idx = attrs_all < 0 if sign == "-" else attrs_all > 0
298- attrs_all_signed = attrs_all [idx ]
299- attrs_x_avs_signed = attrs_x_avs [idx ]
300- attrs_x_avs_signed [
301- (attrs_x_avs_signed > 0 ) if sign == "-" else (attrs_x_avs_signed < 0 )
302- ] = 0 # set incompatibly signed attrs to 0
303- attrs_x_avs_contrib = (
304- attrs_x_avs_signed / attrs_all_signed
305- ) # get element-wise contribution ratio
306- attrs_x_avs_contrib [attrs_x_avs_contrib > 1 ] = (
307- 1 # ceiling the max ratio as 1
308- )
309- print (
310- f"{ sign } contribution ratio of x avs attributions to all attributions: { attrs_x_avs_contrib .mean ():.3f} "
311- )
312- return attrs_x_avs_contrib , idx
313-
314- pos_contrib_ratio , pos_contrib_index = compute_attr_contrib (sign = "+" )
315- neg_contrib_ratio , neg_contrib_indx = compute_attr_contrib (sign = "-" )
316-
317- with open (f"{ args .output_prefix } .contrib_ratio.txt" , "w" ) as f :
318- f .write (
319- f"Positive contribution ratio of x avs attributions to all attributions: { pos_contrib_ratio .mean ().item ():.3f} \n "
320- )
321- f .write (
322- f"Negative contribution ratio of x avs attributions to all attributions: { neg_contrib_ratio .mean ().item ():.3f} \n "
283+ # save attr x avs and total attrs per region
284+ if chrom is not None :
285+ pd .DataFrame (
286+ {
287+ "region" : regions_save ,
288+ "attrs_total_seq" : attrs_all_seq .sum (dim = 1 ).numpy (),
289+ "attrs_x_avs_seq" : attrs_x_avs_seq .sum (dim = 1 ).numpy (),
290+ "attrs_total_chrom" : attrs_all_chrom .sum (dim = 1 ).numpy (),
291+ "attrs_x_avs_chrom" : attrs_x_avs_chrom .sum (dim = 1 ).numpy (),
292+ }
293+ ).assign (attrs_x_avs_total = lambda x : x ['attrs_x_avs_seq' ] + x ['attrs_x_avs_chrom' ]).sort_values ("attrs_x_avs_total" , ascending = False ).to_csv (
294+ f"{ args .output_prefix } .regions_with_attrs_x_avs.txt" ,
295+ index = False ,
296+ header = True ,
297+ sep = "\t " ,
323298 )
324- f .write (
325- f"Total contribution ratio of x avs attributions to all attributions: { torch .cat ([pos_contrib_ratio , neg_contrib_ratio ]).mean ().item ():.3f} \n "
299+ else :
300+ pd .DataFrame (
301+ {
302+ "region" : regions_save ,
303+ "attrs_total_seq" : attrs_all_seq .sum (dim = 1 ).numpy (),
304+ "attrs_x_avs_seq" : attrs_x_avs_seq .sum (dim = 1 ).numpy (),
305+ }
306+ ).sort_values ("attrs_x_avs_seq" , ascending = False ).to_csv (
307+ f"{ args .output_prefix } .regions_with_attrs_x_avs.txt" ,
308+ index = False ,
309+ header = True ,
310+ sep = "\t " ,
326311 )
327- # save regions with the attrib ratios
328- contrib_ratio = torch .zeros_like (attrs_all )
329- assert len (contrib_ratio .shape ) == 2
330- contrib_ratio [pos_contrib_index ] = pos_contrib_ratio
331- contrib_ratio [neg_contrib_indx ] = neg_contrib_ratio
332- contrib_ratio_per_region = contrib_ratio .mean (dim = 1 )
333-
334- with open (f"{ args .output_prefix } .regions_with_contrib.txt" , "w" ) as o :
335- for r , cr in zip (regions_save , contrib_ratio_per_region ):
336- o .write (f"{ r } \t { cr .item ()} \n " )
337-
338- # save attr x avs and total attrs per region
339- pd .DataFrame (
340- {
341- "region" : regions_save ,
342- "attrs_total" : attrs_all .sum (dim = 1 ).numpy (),
343- "attrs_x_avs" : attrs_x_avs .sum (dim = 1 ).numpy (),
344- }
345- ).sort_values ("attrs_x_avs" , ascending = False ).to_csv (
346- f"{ args .output_prefix } .regions_with_attrs_x_avs.txt" ,
347- index = False ,
348- header = True ,
349- sep = "\t " ,
350- )
351312
352313
353314if __name__ == "__main__" :
0 commit comments