Skip to content

Commit a00ec4e

Browse files
committed
clean the x_cav script
1 parent f0ddc85 commit a00ec4e

1 file changed

Lines changed: 79 additions & 118 deletions

File tree

scripts/compute_input_attrs_pca_x_cav.py

Lines changed: 79 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ def main():
6565
"output_prefix",
6666
help="Prefix of output file name for attributions saved in npz format",
6767
)
68-
parser.add_argument(
69-
"output_key",
70-
help="Which output should be attributed to, available keys [Oct4_profile, Oct4_counts, Sox2_profile, Sox2_counts, Nanog_profile, Nanog_counts, Klf4_profile, Klf4_counts]",
71-
)
7268
parser.add_argument(
7369
"--cavs-dir",
7470
type=str,
@@ -168,9 +164,12 @@ def main():
168164

169165
# attribution each test sample
170166
# NOTE: there should be only one attribution tensor coming out of layer attribution
171-
attributions = []
172-
attributions_x_avs = []
173-
attributions_remainder = []
167+
attributions_seq = []
168+
attributions_chrom = []
169+
attributions_x_avs_seq = []
170+
attributions_remainder_seq = []
171+
attributions_x_avs_chrom = []
172+
attributions_remainder_chrom = []
174173
regions_save = []
175174
for (region, seq, chrom, _, _), (bseq, bchrom, _, _) in tqdm(
176175
zip(target_dl, baseline_dl)
@@ -181,20 +180,23 @@ def main():
181180

182181
# match repeated input shape
183182
seq = torch.repeat_interleave(seq, repeats=args.num_baselines_per_sample, dim=0)
183+
chrom = torch.repeat_interleave(chrom, repeats=args.num_baselines_per_sample, dim=0)
184184
bseq = bseq[: seq.shape[0]]
185-
inputs = utils.seq_transform_fn(seq.to(device))
186-
binputs = utils.seq_transform_fn(bseq.to(device))
185+
bchrom = bchrom[: chrom.shape[0]]
187186

188-
neutral_biases = {k: v for k, v in inputs.items() if k != "seq"}
187+
seq = utils.seq_transform_fn(seq.to(device))
188+
bseq = utils.seq_transform_fn(bseq.to(device))
189189

190+
chrom = utils.chrom_transform_fn(chrom.to(device))
191+
bchrom = utils.chrom_transform_fn(bchrom.to(device))
192+
193+
inputs = seq if chrom is None else (seq, chrom)
194+
binputs = bseq if chrom is None else (bseq, bchrom)
190195
# attribution on full sequence
191196
attribution = deeplift.attribute(
192-
inputs["seq"],
193-
baselines=binputs["seq"],
197+
inputs,
198+
baselines=binputs,
194199
additional_forward_args=(
195-
neutral_biases,
196-
args.output_key,
197-
True,
198200
cavs_list,
199201
False,
200202
False,
@@ -203,68 +205,46 @@ def main():
203205
# None if args.no_multiply_by_inputs else abs_attribution_func
204206
# ),
205207
) # [# batch, dim_projected+dim_residual]
206-
attributions.append(
207-
attribution.reshape(
208-
-1, args.num_baselines_per_sample, *attribution.shape[1:]
209-
)
210-
.mean(axis=1)
211-
.detach()
212-
.cpu()
213-
)
208+
209+
def reduce_attrs(attrs):
210+
return attrs.reshape(-1, args.num_baselines_per_sample, *attrs.shape[1:]).mean(axis=1).detach().cpu()
211+
212+
if chrom is None:
213+
attributions_seq.append(reduce_attrs(attribution))
214+
else:
215+
attributions_seq.append(reduce_attrs(attribution[0]))
216+
attributions_chrom.append(reduce_attrs(attribution[1]))
214217

215218
if cavs_list is not None:
216219
# attribution on x avs directions
217220
attribution_x_avs = deeplift.attribute(
218-
inputs["seq"],
219-
baselines=binputs["seq"],
221+
inputs,
222+
baselines=binputs,
220223
additional_forward_args=(
221-
neutral_biases,
222-
args.output_key,
223-
True,
224224
cavs_list,
225225
False,
226226
True,
227227
),
228-
# custom_attribution_func=(
229-
# None if args.no_multiply_by_inputs else abs_attribution_func
230-
# ),
231228
) # [# batch, dim_projected+dim_residual]
232229
# attribution on remainder
233230
attribution_remainder = deeplift.attribute(
234-
inputs["seq"],
235-
baselines=binputs["seq"],
231+
inputs,
232+
baselines=binputs,
236233
additional_forward_args=(
237-
neutral_biases,
238-
args.output_key,
239-
True,
240234
cavs_list,
241235
True,
242236
False,
243237
),
244-
# custom_attribution_func=(
245-
# None if args.no_multiply_by_inputs else abs_attribution_func
246-
# ),
247238
) # [# batch, dim_projected+dim_residual]
248-
attributions_x_avs.append(
249-
attribution_x_avs.reshape(
250-
-1, args.num_baselines_per_sample, *attribution_x_avs.shape[1:]
251-
)
252-
.mean(axis=1)
253-
.detach()
254-
.cpu()
255-
)
256-
attributions_remainder.append(
257-
attribution_remainder.reshape(
258-
-1, args.num_baselines_per_sample, *attribution_remainder.shape[1:]
259-
)
260-
.mean(axis=1)
261-
.detach()
262-
.cpu()
263-
)
264239

265-
# make predictions
266-
# target_preds[output_key].append(tpcav_model(inpt_projected.to(device), avs_residual.to(device), args.output_key).detach().cpu())
267-
# baseline_preds[output_key].append(tpcav_model(bavs_projected.to(device), bavs_residual.to(device), args.output_key).detach().cpu())
240+
if chrom is None:
241+
attributions_x_avs_seq.append(reduce_attrs(attribution_x_avs))
242+
attributions_remainder_seq.append(reduce_attrs(attribution_remainder))
243+
else:
244+
attributions_x_avs_seq.append(reduce_attrs(attribution_x_avs[0]))
245+
attributions_x_avs_chrom.append(reduce_attrs(attribution_x_avs[1]))
246+
attributions_remainder_seq.append(reduce_attrs(attribution_remainder[0]))
247+
attributions_remainder_chrom.append(reduce_attrs(attribution_remainder[1]))
268248

269249
with torch.no_grad():
270250
del (
@@ -274,6 +254,9 @@ def main():
274254
)
275255
torch.cuda.empty_cache()
276256

257+
# save regions
258+
np.savetxt(f"{args.output_prefix}.regions.txt", regions_save, fmt="%s")
259+
277260
# save attributions
278261
def save_attrs(attrs, name):
279262
attrs = torch.cat(attrs)
@@ -282,72 +265,50 @@ def save_attrs(attrs, name):
282265
return attrs
283266

284267
# sum over the last dimension to get per base pair attributions
285-
attrs_all = save_attrs(attributions, "attributions").sum(dim=2)
286-
# save regions
287-
np.savetxt(f"{args.output_prefix}.regions.txt", regions_save, fmt="%s")
268+
attrs_all_seq = save_attrs(attributions_seq, "attributions_seq").sum(dim=2)
269+
if chrom is not None:
270+
attrs_all_chrom = save_attrs(attributions_chrom, "attributions_chrom").sum(dim=2)
288271

289272
if cavs_list is not None:
290-
attrs_x_avs = save_attrs(attributions_x_avs, "attributions_x_avs").sum(dim=2)
291-
attrs_remainder = save_attrs(
292-
attributions_remainder, "attributions_remainder"
273+
attrs_x_avs_seq = save_attrs(attributions_x_avs_seq, "attributions_x_avs_seq").sum(dim=2)
274+
attrs_remainder_seq = save_attrs(
275+
attributions_remainder_seq, "attributions_remainder_seq"
293276
).sum(dim=2)
277+
if chrom is not None:
278+
attrs_x_avs_chrom = save_attrs(attributions_x_avs_chrom, "attributions_x_avs_chrom").sum(dim=2)
279+
attrs_remainder_chrom = save_attrs(
280+
attributions_remainder_chrom, "attributions_remainder_chrom"
281+
).sum(dim=2)
294282

295-
# print summary statistics
296-
def compute_attr_contrib(sign="+"):
297-
idx = attrs_all < 0 if sign == "-" else attrs_all > 0
298-
attrs_all_signed = attrs_all[idx]
299-
attrs_x_avs_signed = attrs_x_avs[idx]
300-
attrs_x_avs_signed[
301-
(attrs_x_avs_signed > 0) if sign == "-" else (attrs_x_avs_signed < 0)
302-
] = 0 # set impatible signed attrs as 0
303-
attrs_x_avs_contrib = (
304-
attrs_x_avs_signed / attrs_all_signed
305-
) # get element-wise contribution ratio
306-
attrs_x_avs_contrib[attrs_x_avs_contrib > 1] = (
307-
1 # ceiling the max ratio as 1
308-
)
309-
print(
310-
f"{sign} contribution ratio of x avs attributions to all attributions: {attrs_x_avs_contrib.mean():.3f}"
311-
)
312-
return attrs_x_avs_contrib, idx
313-
314-
pos_contrib_ratio, pos_contrib_index = compute_attr_contrib(sign="+")
315-
neg_contrib_ratio, neg_contrib_indx = compute_attr_contrib(sign="-")
316-
317-
with open(f"{args.output_prefix}.contrib_ratio.txt", "w") as f:
318-
f.write(
319-
f"Positive contribution ratio of x avs attributions to all attributions: {pos_contrib_ratio.mean().item():.3f}\n"
320-
)
321-
f.write(
322-
f"Negative contribution ratio of x avs attributions to all attributions: {neg_contrib_ratio.mean().item():.3f}\n"
283+
# save attr x avs and total attrs per region
284+
if chrom is not None:
285+
pd.DataFrame(
286+
{
287+
"region": regions_save,
288+
"attrs_total_seq": attrs_all_seq.sum(dim=1).numpy(),
289+
"attrs_x_avs_seq": attrs_x_avs_seq.sum(dim=1).numpy(),
290+
"attrs_total_chrom": attrs_all_chrom.sum(dim=1).numpy(),
291+
"attrs_x_avs_chrom": attrs_x_avs_chrom.sum(dim=1).numpy(),
292+
}
293+
).assign(attrs_x_avs_total=lambda x: x['attrs_x_avs_seq'] + x['attrs_x_avs_chrom']).sort_values("attrs_x_avs_total", ascending=False).to_csv(
294+
f"{args.output_prefix}.regions_with_attrs_x_avs.txt",
295+
index=False,
296+
header=True,
297+
sep="\t",
323298
)
324-
f.write(
325-
f"Total contribution ratio of x avs attributions to all attributions: {torch.cat([pos_contrib_ratio, neg_contrib_ratio]).mean().item():.3f}\n"
299+
else:
300+
pd.DataFrame(
301+
{
302+
"region": regions_save,
303+
"attrs_total_seq": attrs_all_seq.sum(dim=1).numpy(),
304+
"attrs_x_avs_seq": attrs_x_avs_seq.sum(dim=1).numpy(),
305+
}
306+
).sort_values("attrs_x_avs_seq", ascending=False).to_csv(
307+
f"{args.output_prefix}.regions_with_attrs_x_avs.txt",
308+
index=False,
309+
header=True,
310+
sep="\t",
326311
)
327-
# save regions with the attrib ratios
328-
contrib_ratio = torch.zeros_like(attrs_all)
329-
assert len(contrib_ratio.shape) == 2
330-
contrib_ratio[pos_contrib_index] = pos_contrib_ratio
331-
contrib_ratio[neg_contrib_indx] = neg_contrib_ratio
332-
contrib_ratio_per_region = contrib_ratio.mean(dim=1)
333-
334-
with open(f"{args.output_prefix}.regions_with_contrib.txt", "w") as o:
335-
for r, cr in zip(regions_save, contrib_ratio_per_region):
336-
o.write(f"{r}\t{cr.item()}\n")
337-
338-
# save attr x avs and total attrs per region
339-
pd.DataFrame(
340-
{
341-
"region": regions_save,
342-
"attrs_total": attrs_all.sum(dim=1).numpy(),
343-
"attrs_x_avs": attrs_x_avs.sum(dim=1).numpy(),
344-
}
345-
).sort_values("attrs_x_avs", ascending=False).to_csv(
346-
f"{args.output_prefix}.regions_with_attrs_x_avs.txt",
347-
index=False,
348-
header=True,
349-
sep="\t",
350-
)
351312

352313

353314
if __name__ == "__main__":

0 commit comments

Comments (0)