Skip to content

Commit f0ddc85

Browse files
committed
Add script for input attrs x avs
Add a new script that can divide input attribution scores into two parts: - attributions coming from activations in parallel with cavs - remaining part that is orthogonal to cavs
1 parent 28327dc commit f0ddc85

1 file changed

Lines changed: 354 additions & 0 deletions

File tree

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
#!/usr/bin/env python3

# CLI description shown by argparse (runtime string — left verbatim).
DESCRIPTION = """
Provide test regions and cavs, compute attribution scores on the input that in parallel to the cav directions and those orthogonal to cav directions.
"""


# stdlib
import logging
import os
import warnings
from argparse import ArgumentParser
from glob import glob

# third-party / project-local
import numpy as np
import pandas as pd
import seqchromloader as scl
import torch
import utils
from captum.attr import DeepLift
from pybedtools import BedTool
from sklearn.metrics import precision_recall_fscore_support as score
from tqdm import tqdm

# silence noisy pandas/sklearn deprecation chatter
warnings.simplefilter(action="ignore", category=FutureWarning)

# run on the first GPU when available, otherwise CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# fixed seed for reproducible baseline sampling
SEED = 1001
random_state = np.random.RandomState(SEED)

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
logger.info("Compute Layer Attrs~")
40+
41+
def abs_attribution_func(multipliers, inputs, baselines):
    """Return multiplier * |input - baseline| for each input tuple element.

    Taking the absolute difference keeps the sign of the attribution coming
    solely from the multipliers, avoiding a duplicated sign contribution from
    both (inputs - baselines) and the concept CAVs.
    """
    results = []
    for inp, base, mult in zip(inputs, baselines, multipliers):
        delta = (inp - base).abs()
        results.append(delta * mult)
    return tuple(results)
48+
49+
50+
def main():
    """Compute DeepLIFT input attributions for test regions and, when CAVs are
    given, split them into a CAV-parallel component and an orthogonal
    remainder, saving tensors and per-region summaries to disk."""
    parser = ArgumentParser(description=DESCRIPTION)
    parser.add_argument("tpcav_model", help="TPCAV model")
    parser.add_argument(
        "test_bed",
        type=str,
        default=None,
        help="Test bed files will compute layer attributions on",
    )
    parser.add_argument(
        "input_window_length", default=1024, type=int, help="input window length"
    )
    parser.add_argument("genome_fasta_file", help="Genome fasta file")
    parser.add_argument("genome_size_file", help="Genome size file")
    parser.add_argument(
        "output_prefix",
        help="Prefix of output file name for attributions saved in npz format",
    )
    parser.add_argument(
        "output_key",
        help="Which output should be attributed to, available keys [Oct4_profile, Oct4_counts, Sox2_profile, Sox2_counts, Nanog_profile, Nanog_counts, Klf4_profile, Klf4_counts]",
    )
    parser.add_argument(
        "--cavs-dir",
        type=str,
        default=None,
        help="Directory containing CAV subdirs, this would be used to disentangle the attributions into those explained by CAVs and those not",
    )
    parser.add_argument(
        "--cavs",
        type=str,
        nargs="+",
        default=None,
        help="CAVs to use, each CAV should be a folder containing classifier_weights.pt and classifier_perform_on_test.txt",
    )
    parser.add_argument(
        "--num-baselines-per-sample",
        type=int,
        default=10,
        help="How many regions sampled as background when attributing each region",
    )
    args = parser.parse_args()

    # ============================== load cavs =========================================
    def _load_cav(weight_fn):
        """Load one CAV weight tensor and log its classifier's test F-score."""
        perform_fn = weight_fn.replace(
            "classifier_weights.pt", "classifier_perform_on_test.txt"
        )
        # performance table has two columns [Pred, Truth]
        perform = pd.read_table(perform_fn, comment="#")
        _, _, fscores, _ = score(perform.Truth, perform.Pred)
        cav = torch.load(weight_fn, map_location="cpu")[0]
        logger.info(f"Loading CAV from {weight_fn}, fscore: {np.mean(fscores):.3f}")
        return cav

    # iterate through all cavs, check performance of classifier
    cavs_list = []
    if args.cavs_dir is not None:
        # BUGFIX: '**' only matches nested directories when recursive=True;
        # without it CAVs in sub-subdirectories were silently skipped.
        cav_weight_fn_list = glob(
            os.path.join(args.cavs_dir, "**/*classifier_weights.pt"),
            recursive=True,
        )
        logger.info(f"Found {len(cav_weight_fn_list)} CAVs in {args.cavs_dir}")
        for fn in cav_weight_fn_list:
            cavs_list.append(_load_cav(fn))
    if args.cavs is not None:
        for cav_dir in args.cavs:
            cav_weight_fn = glob(os.path.join(cav_dir, "*classifier_weights.pt"))[0]
            cavs_list.append(_load_cav(cav_weight_fn))

    logger.info(f"{len(cavs_list)} CAVs loaded")
    if not cavs_list:
        cavs_list = None

    # ============================== load data =========================================
    ## target test bed samples
    target_df = BedTool(args.test_bed).to_dataframe()
    target_df = utils.center_windows(target_df, window_len=args.input_window_length)
    target_df["label"] = -1
    target_df["strand"] = "+"
    target_dl = scl.SeqChromDatasetByDataFrame(
        target_df,
        genome_fasta=args.genome_fasta_file,
        bigwig_filelist=[],
        return_region=True,
        dataloader_kws={"batch_size": 8, "drop_last": False},
    )
    ## baseline samples: random genomic windows, num_baselines_per_sample per target
    random_regs = scl.random_coords(
        gs=args.genome_size_file,
        l=args.input_window_length,
        n=len(target_df) * 20,
    )
    random_regs["label"] = -1
    random_regs["strand"] = "+"
    baseline_dl = scl.SeqChromDatasetByDataFrame(
        random_regs,
        genome_fasta=args.genome_fasta_file,
        bigwig_filelist=[],
        dataloader_kws={
            "batch_size": 8 * args.num_baselines_per_sample,
            "drop_last": True,
        },
    )

    # ============================== Attribution in PCA space ================================
    # load the TPCAV model
    tpcav_model = torch.load(args.tpcav_model)
    tpcav_model.eval()
    tpcav_model.to(device)

    tpcav_model.forward = tpcav_model.forward_from_start  # set forward function
    deeplift = DeepLift(tpcav_model, multiply_by_inputs=True)

    def _attribute(inputs, binputs, neutral_biases, remainder_flag, x_avs_flag):
        """Run one DeepLIFT attribution pass on the sequence input.

        NOTE(review): the last two forward args select remainder-only /
        CAV-parallel-only attribution (inferred from the call sites) —
        confirm against the TPCAV forward_from_start signature.
        """
        return deeplift.attribute(
            inputs["seq"],
            baselines=binputs["seq"],
            additional_forward_args=(
                neutral_biases,
                args.output_key,
                True,
                cavs_list,
                remainder_flag,
                x_avs_flag,
            ),
        )  # [# batch, dim_projected+dim_residual]

    def _mean_over_baselines(attr):
        """Average attributions over the repeated baselines of each sample."""
        return (
            attr.reshape(-1, args.num_baselines_per_sample, *attr.shape[1:])
            .mean(axis=1)
            .detach()
            .cpu()
        )

    # attribution each test sample
    # NOTE: there should be only one attribution tensor coming out of layer attribution
    attributions = []
    attributions_x_avs = []
    attributions_remainder = []
    regions_save = []
    for (region, seq, chrom, _, _), (bseq, bchrom, _, _) in tqdm(
        zip(target_dl, baseline_dl)
    ):
        # NOTE(review): zip stops at the shorter loader; baseline_dl drops its
        # last partial batch, so it must yield at least as many batches as
        # target_dl (holds for num_baselines_per_sample <= 20).
        assert bseq.shape[0] == args.num_baselines_per_sample * 8

        regions_save.extend(region)

        # repeat each target sample so it is paired with every sampled baseline
        seq = torch.repeat_interleave(seq, repeats=args.num_baselines_per_sample, dim=0)
        bseq = bseq[: seq.shape[0]]
        inputs = utils.seq_transform_fn(seq.to(device))
        binputs = utils.seq_transform_fn(bseq.to(device))

        # non-sequence inputs are held fixed during attribution
        neutral_biases = {k: v for k, v in inputs.items() if k != "seq"}

        # attribution on full sequence
        attribution = _attribute(inputs, binputs, neutral_biases, False, False)
        attributions.append(_mean_over_baselines(attribution))

        if cavs_list is not None:
            # attribution on x avs directions
            attribution_x_avs = _attribute(inputs, binputs, neutral_biases, False, True)
            # attribution on remainder (orthogonal to the CAVs)
            attribution_remainder = _attribute(
                inputs, binputs, neutral_biases, True, False
            )
            attributions_x_avs.append(_mean_over_baselines(attribution_x_avs))
            attributions_remainder.append(_mean_over_baselines(attribution_remainder))

        # free per-batch GPU tensors before the next iteration
        # (a plain del suffices; no_grad() around del had no effect)
        del seq, bseq, attribution
        torch.cuda.empty_cache()

    # save attributions
    def save_attrs(attrs, name):
        """Concatenate per-batch attributions, save them to disk, return the tensor."""
        attrs = torch.cat(attrs)
        # expected shape: [# regions, seq_len, 4] (one-hot base dimension)
        assert len(attrs.shape) == 3 and attrs.shape[2] == 4
        torch.save(attrs, f"{args.output_prefix}.{name}.pt")
        return attrs

    # sum over the last dimension to get per base pair attributions
    attrs_all = save_attrs(attributions, "attributions").sum(dim=2)
    # save regions
    np.savetxt(f"{args.output_prefix}.regions.txt", regions_save, fmt="%s")

    if cavs_list is not None:
        attrs_x_avs = save_attrs(attributions_x_avs, "attributions_x_avs").sum(dim=2)
        attrs_remainder = save_attrs(
            attributions_remainder, "attributions_remainder"
        ).sum(dim=2)

        # print summary statistics
        def compute_attr_contrib(sign="+"):
            """Per-element ratio of same-signed CAV-parallel attribution to total."""
            idx = attrs_all < 0 if sign == "-" else attrs_all > 0
            attrs_all_signed = attrs_all[idx]
            attrs_x_avs_signed = attrs_x_avs[idx]
            # zero out CAV attributions whose sign disagrees with the total
            attrs_x_avs_signed[
                (attrs_x_avs_signed > 0) if sign == "-" else (attrs_x_avs_signed < 0)
            ] = 0
            attrs_x_avs_contrib = (
                attrs_x_avs_signed / attrs_all_signed
            )  # get element-wise contribution ratio
            attrs_x_avs_contrib[attrs_x_avs_contrib > 1] = (
                1  # ceiling the max ratio as 1
            )
            print(
                f"{sign} contribution ratio of x avs attributions to all attributions: {attrs_x_avs_contrib.mean():.3f}"
            )
            return attrs_x_avs_contrib, idx

        pos_contrib_ratio, pos_contrib_index = compute_attr_contrib(sign="+")
        neg_contrib_ratio, neg_contrib_indx = compute_attr_contrib(sign="-")

        with open(f"{args.output_prefix}.contrib_ratio.txt", "w") as f:
            f.write(
                f"Positive contribution ratio of x avs attributions to all attributions: {pos_contrib_ratio.mean().item():.3f}\n"
            )
            f.write(
                f"Negative contribution ratio of x avs attributions to all attributions: {neg_contrib_ratio.mean().item():.3f}\n"
            )
            f.write(
                f"Total contribution ratio of x avs attributions to all attributions: {torch.cat([pos_contrib_ratio, neg_contrib_ratio]).mean().item():.3f}\n"
            )
        # save regions with the attrib ratios
        contrib_ratio = torch.zeros_like(attrs_all)
        assert len(contrib_ratio.shape) == 2
        contrib_ratio[pos_contrib_index] = pos_contrib_ratio
        contrib_ratio[neg_contrib_indx] = neg_contrib_ratio
        contrib_ratio_per_region = contrib_ratio.mean(dim=1)

        with open(f"{args.output_prefix}.regions_with_contrib.txt", "w") as o:
            for r, cr in zip(regions_save, contrib_ratio_per_region):
                o.write(f"{r}\t{cr.item()}\n")

        # save attr x avs and total attrs per region
        pd.DataFrame(
            {
                "region": regions_save,
                "attrs_total": attrs_all.sum(dim=1).numpy(),
                "attrs_x_avs": attrs_x_avs.sum(dim=1).numpy(),
            }
        ).sort_values("attrs_x_avs", ascending=False).to_csv(
            f"{args.output_prefix}.regions_with_attrs_x_avs.txt",
            index=False,
            header=True,
            sep="\t",
        )
351+
352+
353+
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)